diff --git a/error.go b/error.go index a67527acd..51e2d3ee8 100644 --- a/error.go +++ b/error.go @@ -28,6 +28,10 @@ var ( ErrConditionType = errors.New("Unsupported condition type") // ErrUnSupportedSQLType parameter of SQL is not supported ErrUnSupportedSQLType = errors.New("unsupported sql type") + // ErrNoPrimaryKey represents an error lack of primary key + ErrNoPrimaryKey = errors.New("Current table has no necessary primary key") + // ErrMapKeyIsNotValid represents an error map key is not valid + ErrMapKeyIsNotValid = errors.New("Map key type must be a slice because the table have serval primary keys") ) // ErrFieldIsNotExist columns does not exist @@ -49,3 +53,18 @@ type ErrFieldIsNotValid struct { func (e ErrFieldIsNotValid) Error() string { return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) } + +// ErrPrimaryKeyNoSelected represents an error primary key not selected +type ErrPrimaryKeyNoSelected struct { + PrimaryKey string +} + +func (e ErrPrimaryKeyNoSelected) Error() string { + return fmt.Sprintf("primary key %s is not selected", e.PrimaryKey) +} + +// IsErrPrimaryKeyNoSelected returns true is err is ErrPrimaryKeyNoSelected +func IsErrPrimaryKeyNoSelected(err error) bool { + _, ok := err.(ErrPrimaryKeyNoSelected) + return ok +} diff --git a/session_find.go b/session_find.go index 6b8aa469d..43c6d892f 100644 --- a/session_find.go +++ b/session_find.go @@ -250,10 +250,21 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va } else { keyType := containerValue.Type().Key() if len(table.PrimaryKeys) == 0 { - return errors.New("don't support multiple primary key's map has non-slice key type") + return ErrNoPrimaryKey } if len(table.PrimaryKeys) > 1 && keyType.Kind() != reflect.Slice { - return errors.New("don't support multiple primary key's map has non-slice key type") + return ErrMapKeyIsNotValid + } + + var found bool + for _, field := range fields { + if strings.EqualFold(field, table.PrimaryKeys[0]) { + found = true + break + } + } + if !found { + return ErrPrimaryKeyNoSelected{table.PrimaryKeys[0]} } containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { diff --git a/session_find_test.go b/session_find_test.go index f805f06e0..e48343608 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestJoinLimit(t *testing.T) { @@ -801,3 +801,26 @@ func TestFindJoin(t *testing.T) { Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) assert.NoError(t, err) } + +func TestFindMapCols(t *testing.T) { + type FindMapCols struct { + Id int64 + ColA string + ColB string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(FindMapCols)) + + id := testEngine.GetColumnMapper().Obj2Table("Id") + colA := testEngine.GetColumnMapper().Obj2Table("ColA") + colB := testEngine.GetColumnMapper().Obj2Table("ColB") + + var objs = make(map[int64]*FindMapCols) + err := testEngine.Cols(colA, colB).Find(&objs) + assert.Error(t, err) + assert.True(t, IsErrPrimaryKeyNoSelected(err)) + + err = testEngine.Cols(id, colA, colB).Find(&objs) + assert.NoError(t, err) +}