Skip to content

Commit

Permalink
[Unity] Implement relax.Function.bind_params (#15626)
Browse files Browse the repository at this point in the history
* [Unity] Implement relax.Function.bind_params

Similar to `relax.Function.bind_symbolic_vars`, implemented in
#15509, this commit introduces
`relax.Function.bind_params` to allow Relax parameters to be
manipulated on a per-function basis.  This utility function and the
existing `BindParams` transform both use the same underlying
implementation.

* Update relay_translator unit tests to avoid duplicate binding

* Updated unit test that attempted to bind non-existent parameter
  • Loading branch information
Lunderberg authored Sep 6, 2023
1 parent c5b7afc commit ec4a8b3
Show file tree
Hide file tree
Showing 10 changed files with 432 additions and 42 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);

/*!
* \brief Bind symbolic vars to constant shape values.
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
#ifndef TVM_RELAX_UTILS_H_
#define TVM_RELAX_UTILS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>

namespace tvm {
Expand All @@ -48,6 +50,26 @@ namespace relax {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});

/*!
* \brief Infer a binding map for symbolic variables
*
* If a set of relax variables are replaced within an expression, this
* may result in removal of the definition site of a symbolic
* variable. This utility function determines the symbolic variable
* replacements that can be inferred based on the replaced relax
* variables, and can be used alongside the `Bind` utility function to
* replace both the relax variables and the implied symbolic
* variables.
*
* \param binds A map of relax variables to relax expressions
*
* \param analyzer The analyzer to use for simplifications
*
* \return A map of TIR variables to TIR expressions
*/
TVM_DLL tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& binds, arith::Analyzer* analyzer);

/*!
* \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
* dtype).
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,56 @@ def bind_symbolic_vars(

return _ffi_api.FunctionBindSymbolicVars(self, binding_map) # type: ignore

def bind_params(
self,
binding_map: Mapping[
Union[str, Var],
Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
],
) -> "Function":
"""Return a new function with updated symbolic variable
Parameters
----------
binding_map: Mapping[
Union[str, Var],
Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
]
The mapping of values to be replaced.
Keys may be either a `relax.Var` or a string name of the
Relax variable. If the variables are referred to by name,
the name must uniquely identify a parameter in the
function.
Values must be a relax expression, or a value that is
convertible into a relax expression. The value must be
compatible with the variable being replaced.
Returns
-------
func: Function
The updated function
"""

def _normalize_value(value):
# Conversions that must occur prior to the FFI
# conversions.
if isinstance(value, int):
# Relax uses int64 for symbolic variables, but the FFI
# converts python integers into int32.
return tvm.tir.const(value, "int64")
elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)):
return tvm.relax.const(value)
else:
return value

binding_map = {key: _normalize_value(value) for key, value in binding_map.items()}

return _ffi_api.FunctionBindParams(self, binding_map) # type: ignore


@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:

def BindParams(
func_name: str,
params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
params: Dict[Union[str, Var], Union[tvm.runtime.NDArray, np.ndarray]],
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors.
Expand All @@ -415,8 +415,13 @@ def BindParams(
func_name: str
The function name to be bound
params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
The map from param name to constant tensors.
params : Dict[
Union[str,relax.Var],
Union[tvm.runtime.NDArray, np.ndarray],
]
The map from parameter or parameter name name to constant
tensors.
Returns
-------
Expand Down
116 changes: 81 additions & 35 deletions src/relax/transform/bind_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include <tuple>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -81,45 +82,88 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant,
}
}

std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) {
ICHECK(func.defined());
ICHECK(untyped_params.defined());

// Map from string to the variable(s) with that name.
std::unordered_map<std::string, Array<relax::Var>> string_lookup;
std::unordered_set<const relax::VarNode*> var_set;
for (const auto& param : func->params) {
string_lookup[param->name_hint()].push_back(param);
var_set.insert(param.get());
}

Map<relax::Var, relax::Expr> relax_var_remap;

auto normalize_key = [&](ObjectRef obj) -> relax::Var {
if (auto opt_str = obj.as<String>()) {
std::string str = opt_str.value();
auto it = string_lookup.find(str);
CHECK(it != string_lookup.end())
<< "Function does not have parameter with name \"" << str << "\". "
<< "Function parameters are named "
<< func->params.Map([](const auto& param) { return param->name_hint(); });
CHECK_EQ(it->second.size(), 1)
<< "Function contains multiple parameters with name \"" << str << "\". "
<< "The Relax variables " << it->second << " are all named \"" << str << "\"";
auto var = it->second[0];
CHECK(!relax_var_remap.count(var))
<< "Remap of variable " << var << " was defined multiple times";

return var;
} else if (auto opt_var = obj.as<relax::Var>()) {
auto var = opt_var.value();
CHECK(!relax_var_remap.count(var))
<< "Remap of variable " << var << " was defined multiple times";
CHECK(var_set.count(var.get()))
<< "Function does not use Relax variable " << var << " as a parameter. "
<< "Function parameters are " << func->params;
return var;
} else {
LOG(FATAL)
<< "Expected bound parameter to be a relax::Var, "
<< " or a string that uniquely identifies a relax::Var param within the function. "
<< "However, received object " << obj << " of type " << obj->GetTypeKey();
}
};
auto normalize_value = [&](ObjectRef obj) -> relax::Expr {
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
return Constant(opt.value());
} else {
LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
<< " into relax expression";
}
};

for (const auto& [key, value] : untyped_params) {
relax_var_remap.Set(normalize_key(key), normalize_value(value));
}

arith::Analyzer analyzer;
Map<tir::Var, PrimExpr> symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer);

// for (const auto& [bind_param, bind_expr] : relax_var_remap) {
// MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer);
// }

return {relax_var_remap, symbolic_var_map};
}

/*!
* \brief Bind params to function by using name
* \param func Relax function
* \param params params dict
* \return Function
*/
inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
std::unordered_map<std::string, Var> name_dict;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(name_dict[name]);
} else {
name_dict[name] = arg;
}
}
Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& untyped_params) {
auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);

arith::Analyzer analyzer;
Map<Var, Expr> bind_dict;
Map<tir::Var, PrimExpr> symbolic_var_map;

for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
const Var& arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
}
Expr const_expr = Constant(kv.second);
bind_dict.Set(arg, const_expr);
MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer);
}
Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function."
<< "\n";
return ret;
return Downcast<Function>(bound_expr);
}

/*!
Expand All @@ -129,7 +173,7 @@ inline Function BindParamsByName(Function func, const Map<String, runtime::NDArr
* \param param The param dict
* \return The module after binding params.
*/
IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef> bind_params) {
IRModuleNode* new_module = m.CopyOnWrite();
Map<GlobalVar, BaseFunc> functions = m->functions;
for (const auto& func_pr : functions) {
Expand All @@ -138,13 +182,13 @@ IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> p
// Use global_symbol if it's external linkage
Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol.value() == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
} else {
// Use global var's name_hint if it's internal linkage
if (func_pr.first->name_hint == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
}
Expand All @@ -153,9 +197,11 @@ IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> p
return GetRef<IRModule>(new_module);
}

TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams);

namespace transform {

Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
return CreateModulePass(pass_func, 0, "BindParams", {});
Expand Down
56 changes: 56 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,62 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}

tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
tvm::Map<tir::Var, PrimExpr> tir_var_remap;

auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape,
const PrimExpr& expr_shape) {
if (auto var = var_shape.as<tir::Var>()) {
tir_var_remap.Set(var.value(), expr_shape);
}
};

auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) {
auto var_shape = var.as<ShapeStructInfoNode>();
if (!var_shape) return;
if (!var_shape->values.defined()) return;

auto expr_shape = expr.as<ShapeStructInfoNode>();
CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
if (!expr_shape->values.defined()) return;

auto var_shape_arr = var_shape->values.value();
auto expr_shape_arr = expr_shape->values.value();
CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
<< "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
<< " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
for (size_t i = 0; i < var_shape_arr.size(); i++) {
bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
}
};

auto bind_from_tensor = [&bind_from_shape](const StructInfo& var, const StructInfo& expr) {
auto var_tensor = var.as<TensorStructInfoNode>();
if (!var_tensor) return;
if (!var_tensor->shape.defined()) return;

auto expr_tensor = expr.as<TensorStructInfoNode>();
CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
if (!expr_tensor->shape.defined()) return;

bind_from_shape(GetStructInfo(var_tensor->shape.value()),
GetStructInfo(expr_tensor->shape.value()));
};

for (const auto& [relax_var, relax_expr] : relax_var_remap) {
auto var_sinfo = GetStructInfo(relax_var);
auto expr_sinfo = GetStructInfo(relax_expr);

bind_from_tensor(var_sinfo, expr_sinfo);
bind_from_shape(var_sinfo, expr_sinfo);
}

return tir_var_remap;
}

bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
bool permit_unknown_dtype) {
const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
Expand Down
Loading

0 comments on commit ec4a8b3

Please sign in to comment.