Skip to content

Commit bcfe634

Browse files
RDSK-9311: (base/board) Guard against nil responses in rdk client/servers (#4618)
1 parent 2b34a40 commit bcfe634

File tree

6 files changed

+194
-2
lines changed

6 files changed

+194
-2
lines changed

components/base/client_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ func setupBrokenBase(brokenBase *inject.Base) {
7575
brokenBase.PropertiesFunc = func(ctx context.Context, extra map[string]interface{}) (base.Properties, error) {
7676
return base.Properties{}, errPropertiesFailed
7777
}
78+
79+
brokenBase.GeometriesFunc = func(ctx context.Context) ([]spatialmath.Geometry, error) {
80+
return nil, nil
81+
}
7882
}
7983

8084
func TestClient(t *testing.T) {
@@ -221,6 +225,9 @@ func TestClient(t *testing.T) {
221225
_, err = failingBaseClient.Properties(context.Background(), nil)
222226
test.That(t, err.Error(), test.ShouldContainSubstring, errPropertiesFailed.Error())
223227

228+
_, err = failingBaseClient.Geometries(context.Background(), nil)
229+
test.That(t, err.Error(), test.ShouldContainSubstring, base.ErrGeometriesNil(failBaseName).Error())
230+
224231
err = failingBaseClient.Stop(context.Background(), nil)
225232
test.That(t, err.Error(), test.ShouldContainSubstring, errStopFailed.Error())
226233

components/base/server.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// Package base contains a gRPC based arm service server.
1+
// Package base contains a gRPC based base service server.
22
package base
33

44
import (
55
"context"
6+
"fmt"
67

78
commonpb "go.viam.com/api/common/v1"
89
pb "go.viam.com/api/component/base/v1"
@@ -13,6 +14,11 @@ import (
1314
"go.viam.com/rdk/spatialmath"
1415
)
1516

17+
// ErrGeometriesNil is the returned error if base geometries are nil.
18+
var ErrGeometriesNil = func(baseName string) error {
19+
return fmt.Errorf("base component %v Geometries should not return nil geometries", baseName)
20+
}
21+
1622
// serviceServer implements the BaseService from base.proto.
1723
type serviceServer struct {
1824
pb.UnimplementedBaseServiceServer
@@ -162,6 +168,9 @@ func (s *serviceServer) GetGeometries(ctx context.Context, req *commonpb.GetGeom
162168
if err != nil {
163169
return nil, err
164170
}
171+
if geometries == nil {
172+
return nil, ErrGeometriesNil(req.GetName())
173+
}
165174
return &commonpb.GetGeometriesResponse{Geometries: spatialmath.NewGeometriesToProto(geometries)}, nil
166175
}
167176

components/base/server_test.go

+32
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ import (
44
"context"
55
"testing"
66

7+
"github.com/golang/geo/r3"
78
"github.com/pkg/errors"
9+
pbcommon "go.viam.com/api/common/v1"
810
pb "go.viam.com/api/component/base/v1"
911
"go.viam.com/test"
1012

1113
"go.viam.com/rdk/components/base"
1214
"go.viam.com/rdk/resource"
15+
"go.viam.com/rdk/spatialmath"
1316
"go.viam.com/rdk/testutils/inject"
1417
)
1518

@@ -191,4 +194,33 @@ func TestServer(t *testing.T) {
191194
test.That(t, resp, test.ShouldBeNil)
192195
test.That(t, resource.IsNotFoundError(err), test.ShouldBeTrue)
193196
})
197+
198+
t.Run("Geometries", func(t *testing.T) {
199+
box, err := spatialmath.NewBox(
200+
spatialmath.NewPose(r3.Vector{X: 0, Y: 0, Z: 0}, spatialmath.NewZeroPose().Orientation()),
201+
r3.Vector{},
202+
testBaseName,
203+
)
204+
test.That(t, err, test.ShouldBeNil)
205+
206+
// on a successful get geometries
207+
workingBase.GeometriesFunc = func(ctx context.Context) ([]spatialmath.Geometry, error) {
208+
return []spatialmath.Geometry{box}, nil
209+
}
210+
req := &pbcommon.GetGeometriesRequest{Name: testBaseName}
211+
resp, err := server.GetGeometries(context.Background(), req) // TODO (rh) rename server to bServer after review
212+
test.That(t, resp, test.ShouldResemble, &pbcommon.GetGeometriesResponse{
213+
Geometries: spatialmath.NewGeometriesToProto([]spatialmath.Geometry{box}),
214+
})
215+
test.That(t, err, test.ShouldBeNil)
216+
217+
// on a failing get properties
218+
brokenBase.GeometriesFunc = func(ctx context.Context) ([]spatialmath.Geometry, error) {
219+
return nil, nil
220+
}
221+
req = &pbcommon.GetGeometriesRequest{Name: failBaseName}
222+
resp, err = server.GetGeometries(context.Background(), req)
223+
test.That(t, resp, test.ShouldBeNil)
224+
test.That(t, err, test.ShouldBeError, base.ErrGeometriesNil(failBaseName))
225+
})
194226
}

components/board/server.go

+44
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package board
33

44
import (
55
"context"
6+
"fmt"
67

78
commonpb "go.viam.com/api/common/v1"
89
pb "go.viam.com/api/component/board/v1"
@@ -11,6 +12,21 @@ import (
1112
"go.viam.com/rdk/resource"
1213
)
1314

15+
var (
16+
// ErrGPIOPinByNameReturnNil is the error returned when a gpio pin is nil.
17+
ErrGPIOPinByNameReturnNil = func(boardName string) error {
18+
return fmt.Errorf("board component %v GPIOPinByName should not return nil pin", boardName)
19+
}
20+
// ErrAnalogByNameReturnNil is the error returned when an analog is nil.
21+
ErrAnalogByNameReturnNil = func(boardName string) error {
22+
return fmt.Errorf("board component %v AnalogByName should not return nil analog", boardName)
23+
}
24+
// ErrDigitalInterruptByNameReturnNil is the error returned when a digital interrupt is nil.
25+
ErrDigitalInterruptByNameReturnNil = func(boardName string) error {
26+
return fmt.Errorf("board component %v DigitalInterruptByName should not return nil digital interrupt", boardName)
27+
}
28+
)
29+
1430
// serviceServer implements the BoardService from board.proto.
1531
type serviceServer struct {
1632
pb.UnimplementedBoardServiceServer
@@ -34,6 +50,9 @@ func (s *serviceServer) SetGPIO(ctx context.Context, req *pb.SetGPIORequest) (*p
3450
if err != nil {
3551
return nil, err
3652
}
53+
if p == nil {
54+
return nil, ErrGPIOPinByNameReturnNil(req.Name)
55+
}
3756

3857
return &pb.SetGPIOResponse{}, p.Set(ctx, req.High, req.Extra.AsMap())
3958
}
@@ -49,6 +68,9 @@ func (s *serviceServer) GetGPIO(ctx context.Context, req *pb.GetGPIORequest) (*p
4968
if err != nil {
5069
return nil, err
5170
}
71+
if p == nil {
72+
return nil, ErrGPIOPinByNameReturnNil(req.Name)
73+
}
5274

5375
high, err := p.Get(ctx, req.Extra.AsMap())
5476
if err != nil {
@@ -68,6 +90,9 @@ func (s *serviceServer) PWM(ctx context.Context, req *pb.PWMRequest) (*pb.PWMRes
6890
if err != nil {
6991
return nil, err
7092
}
93+
if p == nil {
94+
return nil, ErrGPIOPinByNameReturnNil(req.Name)
95+
}
7196

7297
pwm, err := p.PWM(ctx, req.Extra.AsMap())
7398
if err != nil {
@@ -87,6 +112,9 @@ func (s *serviceServer) SetPWM(ctx context.Context, req *pb.SetPWMRequest) (*pb.
87112
if err != nil {
88113
return nil, err
89114
}
115+
if p == nil {
116+
return nil, ErrGPIOPinByNameReturnNil(req.Name)
117+
}
90118

91119
return &pb.SetPWMResponse{}, p.SetPWM(ctx, req.DutyCyclePct, req.Extra.AsMap())
92120
}
@@ -102,6 +130,9 @@ func (s *serviceServer) PWMFrequency(ctx context.Context, req *pb.PWMFrequencyRe
102130
if err != nil {
103131
return nil, err
104132
}
133+
if p == nil {
134+
return nil, ErrGPIOPinByNameReturnNil(req.Name)
135+
}
105136

106137
freq, err := p.PWMFreq(ctx, req.Extra.AsMap())
107138
if err != nil {
@@ -125,6 +156,9 @@ func (s *serviceServer) SetPWMFrequency(
125156
if err != nil {
126157
return nil, err
127158
}
159+
if p == nil {
160+
return nil, ErrGPIOPinByNameReturnNil(req.Name)
161+
}
128162

129163
return &pb.SetPWMFrequencyResponse{}, p.SetPWMFreq(ctx, uint(req.FrequencyHz), req.Extra.AsMap())
130164
}
@@ -143,6 +177,9 @@ func (s *serviceServer) ReadAnalogReader(
143177
if err != nil {
144178
return nil, err
145179
}
180+
if theReader == nil {
181+
return nil, ErrAnalogByNameReturnNil(req.BoardName)
182+
}
146183

147184
analogValue, err := theReader.Read(ctx, req.Extra.AsMap())
148185
if err != nil {
@@ -170,6 +207,9 @@ func (s *serviceServer) WriteAnalog(
170207
if err != nil {
171208
return nil, err
172209
}
210+
if analog == nil {
211+
return nil, ErrAnalogByNameReturnNil(req.Name)
212+
}
173213

174214
err = analog.Write(ctx, int(req.Value), req.Extra.AsMap())
175215
if err != nil {
@@ -193,6 +233,9 @@ func (s *serviceServer) GetDigitalInterruptValue(
193233
if err != nil {
194234
return nil, err
195235
}
236+
if interrupt == nil {
237+
return nil, ErrDigitalInterruptByNameReturnNil(req.BoardName)
238+
}
196239

197240
val, err := interrupt.Value(ctx, req.Extra.AsMap())
198241
if err != nil {
@@ -218,6 +261,7 @@ func (s *serviceServer) StreamTicks(
218261
if err != nil {
219262
return err
220263
}
264+
221265
interrupts = append(interrupts, di)
222266
}
223267
err = b.StreamTicks(server.Context(), interrupts, ticksChan, req.Extra.AsMap())

0 commit comments

Comments
 (0)