This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit 038db92

Refactor Object Detection inference to use new Model Trainer type (#3034)
1 parent 5393ebb · commit 038db92

12 files changed: +584 −191 lines

src/ml/neural_net/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ if(APPLE AND HAS_MPS AND NOT TC_BUILD_IOS)
     mps_weight.mm
     mps_device_manager.m
     mps_descriptor_utils.m
+    mps_od_backend.mm
     style_transfer/mps_style_transfer.m
     style_transfer/mps_style_transfer_backend.mm
     style_transfer/mps_style_transfer_utils.m

src/ml/neural_net/mps_compute_context.hpp

Lines changed: 1 addition & 2 deletions
@@ -62,8 +62,7 @@ class mps_compute_context: public compute_context {
       std::function<float(float lower, float upper)> rng);
 
  private:
-
-  std::unique_ptr<mps_command_queue> command_queue_;
+  std::shared_ptr<mps_command_queue> command_queue_;
 };
 
 }  // namespace neural_net
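
The change above to a std::shared_ptr exists because the new object-detection backend introduced in this commit stores its own copy of the command queue (see mps_od_backend::parameters below), so the compute context and the backend now share ownership. A minimal sketch of that relationship, using only names that appear elsewhere in this commit:

  // Inside mps_compute_context::create_object_detector (next file in this diff):
  mps_od_backend::parameters params;
  params.command_queue = command_queue_;  // copies the shared_ptr; both objects co-own the queue
  // A unique_ptr member could not be copied into the parameters struct without
  // transferring or aliasing ownership, hence the switch to shared_ptr.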

src/ml/neural_net/mps_compute_context.mm

Lines changed: 17 additions & 10 deletions
@@ -10,11 +10,13 @@
 
 #include <core/logging/logger.hpp>
 #include <core/storage/fileio/fileio_constants.hpp>
+#include <ml/neural_net/mps_image_augmentation.hpp>
+#include <ml/neural_net/mps_od_backend.hpp>
+#include <ml/neural_net/style_transfer/mps_style_transfer_backend.hpp>
+
 #include <ml/neural_net/mps_cnnmodule.h>
 #include <ml/neural_net/mps_graph_cnnmodule.h>
-#include <ml/neural_net/mps_image_augmentation.hpp>
 
-#import <ml/neural_net/style_transfer/mps_style_transfer_backend.hpp>
 
 namespace turi {
 namespace neural_net {
@@ -125,18 +127,23 @@ float_array_map multiply_mps_od_loss_multiplier(float_array_map config,
 std::unique_ptr<model_backend> mps_compute_context::create_object_detector(
     int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out,
     const float_array_map& config, const float_array_map& weights) {
-  float_array_map updated_config;
+  mps_od_backend::parameters params;
+  params.command_queue = command_queue_;
+  params.n = n;
+  params.c_in = c_in;
+  params.h_in = h_in;
+  params.w_in = w_in;
+  params.c_out = c_out;
+  params.h_out = h_out;
+  params.w_out = w_out;
+  params.weights = weights;
+
   std::vector<std::string> update_keys = {
       "learning_rate", "od_scale_class", "od_scale_no_object", "od_scale_object",
       "od_scale_wh", "od_scale_xy", "gradient_clipping"};
-  updated_config = multiply_mps_od_loss_multiplier(config, update_keys);
-  std::unique_ptr<mps_graph_cnn_module> result(
-      new mps_graph_cnn_module(*command_queue_));
-
-  result->init(/* network_id */ kODGraphNet, n, c_in, h_in, w_in, c_out, h_out,
-               w_out, updated_config, weights);
+  params.config = multiply_mps_od_loss_multiplier(config, update_keys);
 
-  return result;
+  return std::unique_ptr<mps_od_backend>(new mps_od_backend(std::move(params)));
 }
 
 std::unique_ptr<model_backend> mps_compute_context::create_activity_classifier(
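
For orientation (not part of this commit): callers obtain the backend through this factory. A hedged sketch of the call, assuming an existing mps_compute_context named ctx and config/weights float_array_maps already prepared by the object detection toolkit; the tensor sizes below are illustrative placeholders:

  // Sketch only; sizes and variable names are illustrative, not taken from this commit.
  std::unique_ptr<model_backend> od_backend = ctx.create_object_detector(
      /* n */ 32, /* c_in */ 3, /* h_in */ 416, /* w_in */ 416,
      /* c_out */ 15, /* h_out */ 13, /* w_out */ 13, config, weights);
  // After this change, the returned model_backend is an mps_od_backend, which
  // manages separate mps_graph_cnn_module instances for training and prediction.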

src/ml/neural_net/mps_od_backend.hpp

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+/* Copyright © 2020 Apple Inc. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-3-clause license that can
+ * be found in the LICENSE.txt file or at
+ * https://opensource.org/licenses/BSD-3-Clause
+ */
+
+#ifndef MPS_OD_BACKEND_HPP_
+#define MPS_OD_BACKEND_HPP_
+
+#include <ml/neural_net/mps_graph_cnnmodule.h>
+#include <ml/neural_net/model_backend.hpp>
+
+namespace turi {
+namespace neural_net {
+
+/**
+ * Model backend for object detection that uses a separate mps_graph_cnnmodule
+ * for training and for inference, since mps_graph_cnnmodule doesn't currently
+ * support doing both.
+ */
+class mps_od_backend : public model_backend {
+ public:
+  struct parameters {
+    std::shared_ptr<mps_command_queue> command_queue;
+    int n;
+    int c_in;
+    int h_in;
+    int w_in;
+    int c_out;
+    int h_out;
+    int w_out;
+    float_array_map config;
+    float_array_map weights;
+  };
+
+  mps_od_backend(parameters params);
+
+  // Training
+  void set_learning_rate(float lr) override;
+  float_array_map train(const float_array_map& inputs) override;
+
+  // Inference
+  float_array_map predict(const float_array_map& inputs) const override;
+
+  float_array_map export_weights() const override;
+
+ private:
+  void ensure_training_module();
+  void ensure_prediction_module() const;
+
+  parameters params_;
+
+  std::unique_ptr<mps_graph_cnn_module> training_module_;
+
+  // Cleared whenever the training module is updated.
+  mutable std::unique_ptr<mps_graph_cnn_module> prediction_module_;
+};
+
+}  // namespace neural_net
+}  // namespace turi
+
+#endif  // MPS_OD_BACKEND_HPP_
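
As the class comment states, training and inference go through separate mps_graph_cnn_module instances because a single module cannot currently do both. A hedged caller-side sketch of the intended call pattern; params is populated as in mps_compute_context.mm above, and batch_inputs / eval_inputs are placeholder float_array_maps, not names from this commit:

  mps_od_backend backend(std::move(params));             // eagerly builds one module, chosen by config["mode"]
  backend.set_learning_rate(0.001f);                     // ensures the training module exists
  float_array_map loss = backend.train(batch_inputs);    // trains; drops any cached prediction module
  float_array_map preds = backend.predict(eval_inputs);  // lazily builds a prediction module from current weights
  float_array_map snapshot = backend.export_weights();   // exported from the training module when one exists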

src/ml/neural_net/mps_od_backend.mm

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+/* Copyright © 2020 Apple Inc. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-3-clause license that can
+ * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+ */
+
+#include <ml/neural_net/mps_od_backend.hpp>
+
+namespace turi {
+namespace neural_net {
+
+void mps_od_backend::ensure_training_module() {
+  if (training_module_) return;
+
+  training_module_.reset(new mps_graph_cnn_module(*params_.command_queue));
+  training_module_->init(/* network_id */ kODGraphNet, params_.n, params_.c_in, params_.h_in,
+                         params_.w_in, params_.c_out, params_.h_out, params_.w_out, params_.config,
+                         params_.weights);
+
+  // Clear params_.weights to free up memory, since they are now superseded by
+  // whatever the training module contains.
+  params_.weights.clear();
+}
+
+void mps_od_backend::ensure_prediction_module() const {
+  if (prediction_module_) return;
+
+  // Adjust configuration for prediction.
+  float_array_map config = params_.config;
+  config["mode"] = shared_float_array::wrap(2.0f);
+  config["od_include_loss"] = shared_float_array::wrap(0.0f);
+
+  // Take weights from the training module if present, else from the original weights.
+  float_array_map weights;
+  if (training_module_) {
+    weights = training_module_->export_weights();
+  } else {
+    weights = params_.weights;
+  }
+
+  prediction_module_.reset(new mps_graph_cnn_module(*params_.command_queue));
+  prediction_module_->init(/* network_id */ kODGraphNet, params_.n, params_.c_in, params_.h_in,
+                           params_.w_in, params_.c_out, params_.h_out, params_.w_out, config,
+                           weights);
+}
+
+mps_od_backend::mps_od_backend(parameters params) : params_(std::move(params)) {
+  // Immediately instantiate at least one module, since at present we can't
+  // guarantee that the weights will remain valid after we return.
+  // TODO: Remove this eager construction once we stop putting weak pointers in
+  // float_array_map.
+  if (params_.config.at("mode").data()[0] == 0.f) {
+    ensure_training_module();
+  } else {
+    ensure_prediction_module();
+  }
+}
+
+void mps_od_backend::set_learning_rate(float lr) {
+  ensure_training_module();
+  training_module_->set_learning_rate(lr);
+}
+
+float_array_map mps_od_backend::train(const float_array_map& inputs) {
+  // Invalidate prediction_module_, since its weights will be stale.
+  prediction_module_.reset();
+
+  ensure_training_module();
+  return training_module_->train(inputs);
+}
+
+float_array_map mps_od_backend::predict(const float_array_map& inputs) const {
+  ensure_prediction_module();
+  return prediction_module_->predict(inputs);
+}
+
+float_array_map mps_od_backend::export_weights() const {
+  if (training_module_) {
+    return training_module_->export_weights();
+  } else {
+    return params_.weights;
+  }
+}
+
+}  // namespace neural_net
+}  // namespace turi
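
One design note on the implementation above: train() always resets prediction_module_, so in an interleaved train/predict loop the first predict() after training rebuilds the prediction module from freshly exported weights. A minimal sketch of that interaction; the loop bounds and the next_training_batch() / validation_batch() helpers are placeholders, not APIs from this repository:

  for (int iteration = 0; iteration < 1000; ++iteration) {
    backend.train(next_training_batch());       // invalidates any cached prediction module
    if (iteration % 100 == 0) {
      // This predict() pays the cost of re-initializing the prediction
      // mps_graph_cnn_module with weights exported from the training module.
      backend.predict(validation_batch());
    }
  }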
