Skip to content

Commit 12939c4

Browse files
authored
DATA-3462 - Refactor code for using in Cloud inference (viamrobotics#4792)
1 parent e4d85cd commit 12939c4

File tree

5 files changed

+372
-320
lines changed

5 files changed

+372
-320
lines changed

ml/detections.go

+342
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
package ml
2+
3+
import (
4+
"math"
5+
"strconv"
6+
"strings"
7+
"sync"
8+
9+
"github.com/pkg/errors"
10+
"gorgonia.org/tensor"
11+
12+
"go.viam.com/rdk/data"
13+
"go.viam.com/rdk/utils"
14+
)
15+
16+
const (
17+
detectorLocationName = "location"
18+
detectorCategoryName = "category"
19+
detectorScoreName = "score"
20+
)
21+
22+
// FormatDetectionOutputs formats the output tensors from a model into detections.
23+
func FormatDetectionOutputs(outNameMap *sync.Map, outMap Tensors, origW, origH int,
24+
boxOrder []int, labels []string,
25+
) ([]data.BoundingBox, error) {
26+
// use the outNameMap to find the tensor names, or guess and cache the names
27+
locationName, categoryName, scoreName, err := findDetectionTensorNames(outMap, outNameMap)
28+
if err != nil {
29+
return nil, err
30+
}
31+
locations, err := ConvertToFloat64Slice(outMap[locationName].Data())
32+
if err != nil {
33+
return nil, err
34+
}
35+
scores, err := ConvertToFloat64Slice(outMap[scoreName].Data())
36+
if err != nil {
37+
return nil, err
38+
}
39+
hasCategoryTensor := false
40+
categories := make([]float64, len(scores)) // default 0 category if no category output
41+
if categoryName != "" {
42+
hasCategoryTensor = true
43+
categories, err = ConvertToFloat64Slice(outMap[categoryName].Data())
44+
if err != nil {
45+
return nil, err
46+
}
47+
}
48+
// sometimes categories are stuffed into the score output. separate them out.
49+
if !hasCategoryTensor {
50+
shape := outMap[scoreName].Shape()
51+
if len(shape) == 3 { // cartegories are stored in 3rd dimension
52+
nCategories := shape[2] // nCategories usually in 3rd dim, but sometimes in 2nd
53+
if 4*nCategories == len(locations) { // it's actually in 2nd dim
54+
nCategories = shape[1]
55+
}
56+
scores, categories, err = extractCategoriesFromScores(scores, nCategories)
57+
if err != nil {
58+
return nil, errors.Wrap(err, "could not extract categories from score tensor")
59+
}
60+
}
61+
}
62+
63+
// Now reshape outMap into Detections
64+
if len(categories) != len(scores) || 4*len(scores) != len(locations) {
65+
return nil, errors.Errorf(
66+
"output tensor sizes did not match each other as expected. score: %v, category: %v, location: %v",
67+
len(scores),
68+
len(categories),
69+
len(locations),
70+
)
71+
}
72+
detections := make([]data.BoundingBox, 0, len(scores))
73+
detectionBoxesAreProportional := false
74+
for i := 0; i < len(scores); i++ {
75+
// heuristic for knowing if bounding box coordinates are abolute pixel locations, or
76+
// proportional pixel locations. Absolute bounding boxes will not usually be less than a pixel
77+
// and purely located in the upper left corner.
78+
if i == 0 && (locations[0]+locations[1]+locations[2]+locations[3] < 4.) {
79+
detectionBoxesAreProportional = true
80+
}
81+
var xmin, ymin, xmax, ymax float64
82+
if detectionBoxesAreProportional {
83+
xmin = utils.Clamp(locations[4*i+GetIndex(boxOrder, 0)], 0, 1)
84+
ymin = utils.Clamp(locations[4*i+GetIndex(boxOrder, 1)], 0, 1)
85+
xmax = utils.Clamp(locations[4*i+GetIndex(boxOrder, 2)], 0, 1)
86+
ymax = utils.Clamp(locations[4*i+GetIndex(boxOrder, 3)], 0, 1)
87+
} else {
88+
xmin = utils.Clamp(locations[4*i+GetIndex(boxOrder, 0)], 0, float64(origW-1)) / float64(origW-1)
89+
ymin = utils.Clamp(locations[4*i+GetIndex(boxOrder, 1)], 0, float64(origH-1)) / float64(origH-1)
90+
xmax = utils.Clamp(locations[4*i+GetIndex(boxOrder, 2)], 0, float64(origW-1)) / float64(origW-1)
91+
ymax = utils.Clamp(locations[4*i+GetIndex(boxOrder, 3)], 0, float64(origH-1)) / float64(origH-1)
92+
}
93+
labelNum := int(utils.Clamp(categories[i], 0, math.MaxInt))
94+
95+
if labels == nil {
96+
detections = append(detections, data.BoundingBox{
97+
Confidence: &scores[i],
98+
Label: strconv.Itoa(labelNum),
99+
XMinNormalized: xmin,
100+
YMinNormalized: ymin,
101+
XMaxNormalized: xmax,
102+
YMaxNormalized: ymax,
103+
})
104+
} else {
105+
if labelNum >= len(labels) {
106+
return nil, errors.Errorf("cannot access label number %v from label file with %v labels", labelNum, len(labels))
107+
}
108+
detections = append(detections, data.BoundingBox{
109+
Confidence: &scores[i],
110+
Label: labels[labelNum],
111+
XMinNormalized: xmin,
112+
YMinNormalized: ymin,
113+
XMaxNormalized: xmax,
114+
YMaxNormalized: ymax,
115+
})
116+
}
117+
}
118+
return detections, nil
119+
}
120+
121+
// findDetectionTensors finds the tensors that are necessary for object detection
122+
// the returned tensor order is location, category, score. It caches results.
123+
// category is optional, and will return "" if not present.
124+
func findDetectionTensorNames(outMap Tensors, nameMap *sync.Map) (string, string, string, error) {
125+
// first try the nameMap
126+
loc, okLoc := nameMap.Load(detectorLocationName)
127+
score, okScores := nameMap.Load(detectorScoreName)
128+
cat, okCat := nameMap.Load(detectorCategoryName)
129+
if okLoc && okCat && okScores { // names are known
130+
locString, ok := loc.(string)
131+
if !ok {
132+
return "", "", "", errors.Errorf("name map was not storing string, but a type %T", loc)
133+
}
134+
catString, ok := cat.(string)
135+
if !ok {
136+
return "", "", "", errors.Errorf("name map was not storing string, but a type %T", cat)
137+
}
138+
scoreString, ok := score.(string)
139+
if !ok {
140+
return "", "", "", errors.Errorf("name map was not storing string, but a type %T", score)
141+
}
142+
return locString, catString, scoreString, nil
143+
}
144+
if okLoc && okScores { // names are known, just no categories
145+
locString, ok := loc.(string)
146+
if !ok {
147+
return "", "", "", errors.Errorf("name map was not storing string, but a type %T", loc)
148+
}
149+
scoreString, ok := score.(string)
150+
if !ok {
151+
return "", "", "", errors.Errorf("name map was not storing string, but a type %T", score)
152+
}
153+
if len(outMap[scoreString].Shape()) == 3 || len(outMap) == 2 { // the categories are in the score
154+
return locString, "", scoreString, nil
155+
}
156+
}
157+
// next, if nameMap is not set, just see if the outMap has expected names
158+
// if the outMap only has two outputs, it might just be locations and scores.
159+
_, okLoc = outMap[detectorLocationName]
160+
_, okCat = outMap[detectorCategoryName]
161+
_, okScores = outMap[detectorScoreName]
162+
if okLoc && okCat && okScores { // names are as expected
163+
nameMap.Store(detectorLocationName, detectorLocationName)
164+
nameMap.Store(detectorCategoryName, detectorCategoryName)
165+
nameMap.Store(detectorScoreName, detectorScoreName)
166+
return detectorLocationName, detectorCategoryName, detectorScoreName, nil
167+
}
168+
// last, do a hack-y thing to try to guess the tensor names for the detection output tensors
169+
locationName, categoryName, scoreName, err := guessDetectionTensorNames(outMap)
170+
if err != nil {
171+
return "", "", "", err
172+
}
173+
nameMap.Store(detectorLocationName, locationName)
174+
nameMap.Store(detectorCategoryName, categoryName)
175+
nameMap.Store(detectorScoreName, scoreName)
176+
return locationName, categoryName, scoreName, nil
177+
}
178+
179+
// guessDetectionTensors is a hack-y function meant to find the correct detection tensors if the tensors
180+
// were not given the expected names, or have no metadata. This function should succeed
181+
// for models built with the viam platform.
182+
func guessDetectionTensorNames(outMap Tensors) (string, string, string, error) {
183+
foundTensor := map[string]bool{}
184+
mappedNames := map[string]string{}
185+
outNames := TensorNames(outMap)
186+
_, okLoc := outMap[detectorLocationName]
187+
if okLoc {
188+
foundTensor[detectorLocationName] = true
189+
mappedNames[detectorLocationName] = detectorLocationName
190+
}
191+
_, okCat := outMap[detectorCategoryName]
192+
if okCat {
193+
foundTensor[detectorCategoryName] = true
194+
mappedNames[detectorCategoryName] = detectorCategoryName
195+
}
196+
_, okScores := outMap[detectorScoreName]
197+
if okScores {
198+
foundTensor[detectorScoreName] = true
199+
mappedNames[detectorScoreName] = detectorScoreName
200+
}
201+
// first find how many detections there were
202+
// this will be used to find the other tensors
203+
nDetections := 0
204+
for name, t := range outMap {
205+
if _, alreadyFound := foundTensor[name]; alreadyFound {
206+
continue
207+
}
208+
if t.Dims() == 1 { // usually n-detections has its own tensor
209+
val, err := t.At(0)
210+
if err != nil {
211+
return "", "", "", err
212+
}
213+
val64, err := ConvertToFloat64Slice(val)
214+
if err != nil {
215+
return "", "", "", err
216+
}
217+
nDetections = int(val64[0])
218+
foundTensor[name] = true
219+
break
220+
}
221+
}
222+
if !okLoc { // guess the name of the location tensor
223+
// location tensor should have 3 dimensions usually
224+
for name, t := range outMap {
225+
if _, alreadyFound := foundTensor[name]; alreadyFound {
226+
continue
227+
}
228+
if t.Dims() == 3 {
229+
mappedNames[detectorLocationName] = name
230+
foundTensor[name] = true
231+
break
232+
}
233+
}
234+
if _, ok := mappedNames[detectorLocationName]; !ok {
235+
return "", "", "", errors.Errorf("could not find an output tensor named 'location' among [%s]", strings.Join(outNames, ", "))
236+
}
237+
}
238+
if !okCat { // guess the name of the category tensor
239+
// a category usually has a whole number in its elements, so either look for
240+
// int data types in the tensor, or sum the elements and make sure they dont have any decimals
241+
for name, t := range outMap {
242+
if _, alreadyFound := foundTensor[name]; alreadyFound {
243+
continue
244+
}
245+
dt := t.Dtype()
246+
if t.Dims() == 2 {
247+
if dt == tensor.Int || dt == tensor.Int32 || dt == tensor.Int64 ||
248+
dt == tensor.Uint32 || dt == tensor.Uint64 || dt == tensor.Int8 || dt == tensor.Uint8 {
249+
mappedNames[detectorCategoryName] = name
250+
foundTensor[name] = true
251+
break
252+
}
253+
// check if fully whole number
254+
var whole tensor.Tensor
255+
var err error
256+
if nDetections == 0 {
257+
whole, err = tensor.Sum(t)
258+
if err != nil {
259+
return "", "", "", err
260+
}
261+
} else {
262+
s, err := t.Slice(nil, tensor.S(0, nDetections))
263+
if err != nil {
264+
return "", "", "", err
265+
}
266+
whole, err = tensor.Sum(s)
267+
if err != nil {
268+
return "", "", "", err
269+
}
270+
}
271+
val, err := ConvertToFloat64Slice(whole.Data())
272+
if err != nil {
273+
return "", "", "", err
274+
}
275+
if math.Mod(val[0], 1) == 0 {
276+
mappedNames[detectorCategoryName] = name
277+
foundTensor[name] = true
278+
break
279+
}
280+
}
281+
}
282+
if _, ok := mappedNames[detectorCategoryName]; !ok {
283+
return "", "", "", errors.Errorf("could not find an output tensor named 'category' among [%s]", strings.Join(outNames, ", "))
284+
}
285+
}
286+
if !okScores { // guess the name of the scores tensor
287+
// a score usually has a float data type
288+
for name, t := range outMap {
289+
if _, alreadyFound := foundTensor[name]; alreadyFound {
290+
continue
291+
}
292+
dt := t.Dtype()
293+
if t.Dims() == 2 && (dt == tensor.Float32 || dt == tensor.Float64) {
294+
mappedNames[detectorScoreName] = name
295+
foundTensor[name] = true
296+
break
297+
}
298+
}
299+
if _, ok := mappedNames[detectorScoreName]; !ok {
300+
return "", "", "", errors.Errorf("could not find an output tensor named 'score' among [%s]", strings.Join(outNames, ", "))
301+
}
302+
}
303+
return mappedNames[detectorLocationName], mappedNames[detectorCategoryName], mappedNames[detectorScoreName], nil
304+
}
305+
306+
func extractCategoriesFromScores(scores []float64, nCategories int) ([]float64, []float64, error) {
307+
if nCategories == 1 { // trivially every category has the same label
308+
categories := make([]float64, len(scores))
309+
return scores, categories, nil
310+
}
311+
// ensure even division of data into categories
312+
if len(scores)%nCategories != 0 {
313+
return nil, nil, errors.Errorf("nCategories %v does not divide evenly into score tensor of length %v", nCategories, len(scores))
314+
}
315+
nEntries := len(scores) / nCategories
316+
newCategories := make([]float64, 0, nEntries)
317+
newScores := make([]float64, 0, nEntries)
318+
for i := 0; i < nEntries; i++ {
319+
argMax, floatMax, err := argMaxAndMax(scores[nCategories*i : nCategories*i+nCategories])
320+
if err != nil {
321+
return nil, nil, err
322+
}
323+
newCategories = append(newCategories, float64(argMax))
324+
newScores = append(newScores, floatMax)
325+
}
326+
return newScores, newCategories, nil
327+
}
328+
329+
func argMaxAndMax(slice []float64) (int, float64, error) {
330+
if len(slice) == 0 {
331+
return 0, 0.0, errors.New("slice cannot be nil or empty")
332+
}
333+
argMax := 0
334+
floatMax := -math.MaxFloat64
335+
for i, v := range slice {
336+
if v > floatMax {
337+
floatMax = v
338+
argMax = i
339+
}
340+
}
341+
return argMax, floatMax, nil
342+
}

ml/ml.go

+11
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,14 @@ func TensorNames(t Tensors) []string {
215215
}
216216
return names
217217
}
218+
219+
// GetIndex returns the index of an int in an array of ints
220+
// Will return -1 if it's not there.
221+
func GetIndex(s []int, num int) int {
222+
for i, v := range s {
223+
if v == num {
224+
return i
225+
}
226+
}
227+
return -1
228+
}

services/vision/mlvision/classifier.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func attemptToBuildClassifier(mlm mlmodel.Service,
4040
return nil, errors.Errorf("invalid length of shape array (expected 4, got %d)", shapeLen)
4141
}
4242
channelsFirst := false // if channelFirst is true, then shape is (1, 3, height, width)
43-
if shape := md.Inputs[0].Shape; getIndex(shape, 3) == 1 {
43+
if shape := md.Inputs[0].Shape; ml.GetIndex(shape, 3) == 1 {
4444
channelsFirst = true
4545
inHeight, inWidth = shape[2], shape[3]
4646
} else {

0 commit comments

Comments
 (0)