Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: there is no errors when sql is incorrect #123

Merged
merged 2 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 65 additions & 36 deletions client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/apache/iotdb-client-go/common"
"log"
"net"
"reflect"
"sort"
"strings"
"time"

"github.com/apache/iotdb-client-go/common"

"github.com/apache/iotdb-client-go/rpc"
"github.com/apache/thrift/lib/go/thrift"
)
Expand Down Expand Up @@ -103,7 +104,7 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err
ConnectTimeout: time.Duration(connectionTimeoutInMs) * time.Millisecond, // Use 0 for no timeout
})
// s.trans = thrift.NewTFramedTransport(s.trans) // deprecated
var tmp_conf = thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
if !s.trans.IsOpen() {
err = s.trans.Open()
Expand All @@ -115,8 +116,10 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req := rpc.TSOpenSessionReq{
ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password,
}
req.Configuration = make(map[string]string)
req.Configuration["sql_dialect"] = s.config.sqlDialect
if s.config.Version == "" {
Expand Down Expand Up @@ -165,8 +168,10 @@ func (s *Session) OpenCluster(enableRPCCompression bool) error {
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req := rpc.TSOpenSessionReq{
ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password,
}
req.Configuration = make(map[string]string)
req.Configuration["sql_dialect"] = s.config.sqlDialect
if s.config.Version == "" {
Expand Down Expand Up @@ -267,8 +272,10 @@ func (s *Session) DeleteStorageGroups(storageGroupIds ...string) (r *common.TSSt
*error: correctness of operation
*/
func (s *Session) CreateTimeseries(path string, dataType TSDataType, encoding TSEncoding, compressor TSCompressionType, attributes map[string]string, tags map[string]string) (r *common.TSStatus, err error) {
request := rpc.TSCreateTimeseriesReq{SessionId: s.sessionId, Path: path, DataType: int32(dataType), Encoding: int32(encoding),
Compressor: int32(compressor), Attributes: attributes, Tags: tags}
request := rpc.TSCreateTimeseriesReq{
SessionId: s.sessionId, Path: path, DataType: int32(dataType), Encoding: int32(encoding),
Compressor: int32(compressor), Attributes: attributes, Tags: tags,
}
status, err := s.client.CreateTimeseries(context.Background(), &request)
if err != nil && status == nil {
if s.reconnect() {
Expand Down Expand Up @@ -352,8 +359,10 @@ func (s *Session) CreateMultiTimeseries(paths []string, dataTypes []TSDataType,
destCompressions[i] = int32(e)
}

request := rpc.TSCreateMultiTimeseriesReq{SessionId: s.sessionId, Paths: paths, DataTypes: destTypes,
Encodings: destEncodings, Compressors: destCompressions}
request := rpc.TSCreateMultiTimeseriesReq{
SessionId: s.sessionId, Paths: paths, DataTypes: destTypes,
Encodings: destEncodings, Compressors: destCompressions,
}
r, err = s.client.CreateMultiTimeseries(context.Background(), &request)

if err != nil && r == nil {
Expand Down Expand Up @@ -415,8 +424,10 @@ func (s *Session) DeleteData(paths []string, startTime int64, endTime int64) (r
*error: correctness of operation
*/
func (s *Session) InsertStringRecord(deviceId string, measurements []string, values []string, timestamp int64) (r *common.TSStatus, err error) {
request := rpc.TSInsertStringRecordReq{SessionId: s.sessionId, PrefixPath: deviceId, Measurements: measurements,
Values: values, Timestamp: timestamp}
request := rpc.TSInsertStringRecordReq{
SessionId: s.sessionId, PrefixPath: deviceId, Measurements: measurements,
Values: values, Timestamp: timestamp,
}
r, err = s.client.InsertStringRecord(context.Background(), &request)
if err != nil && r == nil {
if s.reconnect() {
Expand All @@ -442,26 +453,33 @@ func (s *Session) SetTimeZone(timeZone string) (r *common.TSStatus, err error) {
return r, err
}

func (s *Session) ExecuteStatement(sql string) (*SessionDataSet, error) {
func (s *Session) ExecuteStatementWithContext(ctx context.Context, sql string) (*SessionDataSet, error) {
request := rpc.TSExecuteStatementReq{
SessionId: s.sessionId,
Statement: sql,
StatementId: s.requestStatementId,
FetchSize: &s.config.FetchSize,
}
resp, err := s.client.ExecuteStatement(context.Background(), &request)
resp, err := s.client.ExecuteStatement(ctx, &request)

if err != nil && resp == nil {
if s.reconnect() {
request.SessionId = s.sessionId
request.StatementId = s.requestStatementId
resp, err = s.client.ExecuteStatement(context.Background(), &request)
resp, err = s.client.ExecuteStatement(ctx, &request)
}
}
if statusErr := VerifySuccess(resp.Status); statusErr != nil {
return nil, statusErr
}

return s.genDataSet(sql, resp), err
}

func (s *Session) ExecuteStatement(sql string) (*SessionDataSet, error) {
return s.ExecuteStatementWithContext(context.Background(), sql)
}

func (s *Session) ExecuteNonQueryStatement(sql string) (r *common.TSStatus, err error) {
request := rpc.TSExecuteStatementReq{
SessionId: s.sessionId,
Expand Down Expand Up @@ -490,8 +508,10 @@ func (s *Session) changeDatabase(database string) {
}

func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionDataSet, error) {
request := rpc.TSExecuteStatementReq{SessionId: s.sessionId, Statement: sql, StatementId: s.requestStatementId,
FetchSize: &s.config.FetchSize, Timeout: timeoutMs}
request := rpc.TSExecuteStatementReq{
SessionId: s.sessionId, Statement: sql, StatementId: s.requestStatementId,
FetchSize: &s.config.FetchSize, Timeout: timeoutMs,
}
if resp, err := s.client.ExecuteQueryStatement(context.Background(), &request); err == nil {
if statusErr := VerifySuccess(resp.Status); statusErr == nil {
return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.client, s.sessionId, resp.QueryDataSet, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, s.config.FetchSize, timeoutMs), err
Expand All @@ -515,10 +535,12 @@ func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionD

func (s *Session) ExecuteAggregationQuery(paths []string, aggregations []common.TAggregationType,
startTime *int64, endTime *int64, interval *int64,
timeoutMs *int64) (*SessionDataSet, error) {

request := rpc.TSAggregationQueryReq{SessionId: s.sessionId, StatementId: s.requestStatementId, Paths: paths,
Aggregations: aggregations, StartTime: startTime, EndTime: endTime, Interval: interval, FetchSize: &s.config.FetchSize, Timeout: timeoutMs}
timeoutMs *int64,
) (*SessionDataSet, error) {
request := rpc.TSAggregationQueryReq{
SessionId: s.sessionId, StatementId: s.requestStatementId, Paths: paths,
Aggregations: aggregations, StartTime: startTime, EndTime: endTime, Interval: interval, FetchSize: &s.config.FetchSize, Timeout: timeoutMs,
}
if resp, err := s.client.ExecuteAggregationQuery(context.Background(), &request); err == nil {
if statusErr := VerifySuccess(resp.Status); statusErr == nil {
return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.client, s.sessionId, resp.QueryDataSet, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, s.config.FetchSize, timeoutMs), err
Expand All @@ -541,11 +563,13 @@ func (s *Session) ExecuteAggregationQuery(paths []string, aggregations []common.

func (s *Session) ExecuteAggregationQueryWithLegalNodes(paths []string, aggregations []common.TAggregationType,
startTime *int64, endTime *int64, interval *int64,
timeoutMs *int64, legalNodes *bool) (*SessionDataSet, error) {

request := rpc.TSAggregationQueryReq{SessionId: s.sessionId, StatementId: s.requestStatementId, Paths: paths,
timeoutMs *int64, legalNodes *bool,
) (*SessionDataSet, error) {
request := rpc.TSAggregationQueryReq{
SessionId: s.sessionId, StatementId: s.requestStatementId, Paths: paths,
Aggregations: aggregations, StartTime: startTime, EndTime: endTime, Interval: interval, FetchSize: &s.config.FetchSize,
Timeout: timeoutMs, LegalPathNodes: legalNodes}
Timeout: timeoutMs, LegalPathNodes: legalNodes,
}
if resp, err := s.client.ExecuteAggregationQuery(context.Background(), &request); err == nil {
if statusErr := VerifySuccess(resp.Status); statusErr == nil {
return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.client, s.sessionId, resp.QueryDataSet, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, s.config.FetchSize, timeoutMs), err
Expand All @@ -570,7 +594,8 @@ func (s *Session) genTSInsertRecordReq(deviceId string, time int64,
measurements []string,
types []TSDataType,
values []interface{},
isAligned bool) (*rpc.TSInsertRecordReq, error) {
isAligned bool,
) (*rpc.TSInsertRecordReq, error) {
request := &rpc.TSInsertRecordReq{}
request.SessionId = s.sessionId
request.PrefixPath = deviceId
Expand Down Expand Up @@ -709,7 +734,7 @@ func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps []
return nil, err
}
}
var isAligned = true
isAligned := true
request := &rpc.TSInsertRecordsOfOneDeviceReq{
SessionId: s.sessionId,
PrefixPath: deviceId,
Expand Down Expand Up @@ -744,7 +769,8 @@ func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps []
*
*/
func (s *Session) InsertRecords(deviceIds []string, measurements [][]string, dataTypes [][]TSDataType, values [][]interface{},
timestamps []int64) (r *common.TSStatus, err error) {
timestamps []int64,
) (r *common.TSStatus, err error) {
request, err := s.genInsertRecordsReq(deviceIds, measurements, dataTypes, values, timestamps, false)
if err != nil {
return nil, err
Expand All @@ -761,7 +787,8 @@ func (s *Session) InsertRecords(deviceIds []string, measurements [][]string, dat
}

func (s *Session) InsertAlignedRecords(deviceIds []string, measurements [][]string, dataTypes [][]TSDataType, values [][]interface{},
timestamps []int64) (r *common.TSStatus, err error) {
timestamps []int64,
) (r *common.TSStatus, err error) {
request, err := s.genInsertRecordsReq(deviceIds, measurements, dataTypes, values, timestamps, true)
if err != nil {
return nil, err
Expand Down Expand Up @@ -932,7 +959,8 @@ func (s *Session) genInsertTabletsReq(tablets []*Tablet, isAligned bool) (*rpc.T
}

func (s *Session) genInsertRecordsReq(deviceIds []string, measurements [][]string, dataTypes [][]TSDataType, values [][]interface{},
timestamps []int64, isAligned bool) (*rpc.TSInsertRecordsReq, error) {
timestamps []int64, isAligned bool,
) (*rpc.TSInsertRecordsReq, error) {
length := len(deviceIds)
if length != len(timestamps) || length != len(measurements) || length != len(values) {
return nil, errLength
Expand Down Expand Up @@ -1169,7 +1197,7 @@ func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) (Session, err
ConnectTimeout: time.Duration(0), // Use 0 for no timeout
})
// session.trans = thrift.NewTFramedTransport(session.trans) // deprecated
var tmp_conf = thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
session.trans = thrift.NewTFramedTransportConf(session.trans, &tmp_conf)
if !session.trans.IsOpen() {
err = session.trans.Open()
Expand All @@ -1196,7 +1224,7 @@ func (s *Session) initClusterConn(node endPoint) error {
})
if err == nil {
// s.trans = thrift.NewTFramedTransport(s.trans) // deprecated
var tmp_conf = thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
if !s.trans.IsOpen() {
err = s.trans.Open()
Expand All @@ -1221,8 +1249,10 @@ func (s *Session) initClusterConn(node endPoint) error {
iprot := s.protocolFactory.GetProtocol(s.trans)
oprot := s.protocolFactory.GetProtocol(s.trans)
s.client = rpc.NewIClientRPCServiceClient(thrift.NewTStandardClient(iprot, oprot))
req := rpc.TSOpenSessionReq{ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password}
req := rpc.TSOpenSessionReq{
ClientProtocol: rpc.TSProtocolVersion_IOTDB_SERVICE_PROTOCOL_V3, ZoneId: s.config.TimeZone, Username: s.config.UserName,
Password: &s.config.Password,
}

resp, err := s.client.OpenSession(context.Background(), &req)
if err != nil {
Expand All @@ -1231,7 +1261,6 @@ func (s *Session) initClusterConn(node endPoint) error {
s.sessionId = resp.GetSessionId()
s.requestStatementId, err = s.client.RequestStatementId(context.Background(), s.sessionId)
return err

}

func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string) *Config {
Expand All @@ -1250,7 +1279,7 @@ func getConfig(host string, port string, userName string, passWord string, fetch

func (s *Session) reconnect() bool {
var err error
var connectedSuccess = false
connectedSuccess := false

for i := 0; i < s.config.ConnectRetryMax; i++ {
for i := range s.endPointList {
Expand Down
10 changes: 9 additions & 1 deletion test/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
package e2e

import (
"context"
"fmt"
"github.com/apache/iotdb-client-go/common"
"log"
"math/rand"
"strings"
"testing"
"time"

"github.com/apache/iotdb-client-go/common"

"github.com/apache/iotdb-client-go/client"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -392,3 +394,9 @@ func (s *e2eTestSuite) Test_InsertAlignedTablets() {
assert.Equal(status, "8")
s.session.DeleteStorageGroup("root.ln.**")
}

func (s *e2eTestSuite) Test_InvalidSQL() {
_, err := s.session.ExecuteStatementWithContext(context.Background(), "select1 from device")
assert := s.Require()
assert.Error(err)
}