diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 05b26f024212..6d3e92b82245 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -180,7 +180,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
  *
  * \return The Pass.
  */
-TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
+TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
 
 /*!
  * \brief Bind symbolic vars to constant shape values.
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 1a6d5d4a5269..0e0249b863a6 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -24,7 +24,9 @@
 #ifndef TVM_RELAX_UTILS_H_
 #define TVM_RELAX_UTILS_H_
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
 #include <tvm/runtime/logging.h>
 
 namespace tvm {
@@ -48,6 +50,26 @@ namespace relax {
 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
                   const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
 
+/*!
+ * \brief Infer a binding map for symbolic variables
+ *
+ * If a set of relax variables is replaced within an expression, the
+ * definition sites of symbolic variables may be removed along with
+ * them.  This utility function determines the symbolic variable
+ * replacements that can be inferred from the replaced relax
+ * variables, and can be used alongside the `Bind` utility function to
+ * replace both the relax variables and the implied symbolic
+ * variables.
+ *
+ * \param binds A map of relax variables to relax expressions
+ *
+ * \param analyzer The analyzer to use for simplifications
+ *
+ * \return A map of TIR variables to TIR expressions
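+ *
+ * A minimal usage sketch, assuming a function `func` and a binding
+ * map `binds` are already in scope:
+ *
+ * \code{.cpp}
+ * arith::Analyzer analyzer;
+ * auto tir_map = InferSymbolicVarMap(binds, &analyzer);
+ * Expr bound = Bind(func, binds, tir_map);
+ * \endcode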
+ */
+TVM_DLL tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
+    const tvm::Map<relax::Var, relax::Expr>& binds, arith::Analyzer* analyzer);
+
 /*!
  * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
  * dtype).
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 49b91ffb3da1..cd5dfa2863a7 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -657,6 +657,56 @@ def bind_symbolic_vars(
 
         return _ffi_api.FunctionBindSymbolicVars(self, binding_map)  # type: ignore
 
+    def bind_params(
+        self,
+        binding_map: Mapping[
+            Union[str, Var],
+            Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
+        ],
+    ) -> "Function":
+        """Return a new function with updated symbolic variable
+
+        Parameters
+        ----------
+        binding_map: Mapping[
+                Union[str, Var],
+                Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
+            ]
+
+            The mapping from parameters to their bound values.
+
+            Keys may be either a `relax.Var` or a string name of the
+            Relax variable.  If the variables are referred to by name,
+            the name must uniquely identify a parameter in the
+            function.
+
+            Each value must be a relax expression, or a value that is
+            convertible into a relax expression, and must be
+            compatible with the variable being replaced.
+
+        Returns
+        -------
+        func: Function
+
+            The updated function
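+
+        Examples
+        --------
+        A minimal sketch, assuming a function ``func`` whose first
+        parameter is a 16-element float32 tensor named ``A``:
+
+        .. code-block:: python
+
+            weights = np.zeros([16], "float32")
+            func_with_bound_weights = func.bind_params({"A": weights})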
+        """
+
+        def _normalize_value(value):
+            # Conversions that must occur prior to the FFI
+            # conversions.
+            if isinstance(value, int):
+                # Relax uses int64 for symbolic variables, but the FFI
+                # converts python integers into int32.
+                return tvm.tir.const(value, "int64")
+            elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)):
+                return tvm.relax.const(value)
+            else:
+                return value
+
+        binding_map = {key: _normalize_value(value) for key, value in binding_map.items()}
+
+        return _ffi_api.FunctionBindParams(self, binding_map)  # type: ignore
+
 
 @tvm._ffi.register_object("relax.expr.ExternFunc")
 class ExternFunc(BaseFunc):
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 438a6d1213e8..407805050547 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -387,7 +387,7 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
 
 def BindParams(
     func_name: str,
-    params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
+    params: Dict[Union[str, Var], Union[tvm.runtime.NDArray, np.ndarray]],
 ) -> tvm.ir.transform.Pass:
     """Bind params of function of the module to constant tensors.
 
@@ -397,8 +397,13 @@ def BindParams(
     func_name: str
         The function name to be bound
 
-    params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
-        The map from param name to constant tensors.
+    params : Dict[
+                Union[str, relax.Var],
+                Union[tvm.runtime.NDArray, np.ndarray],
+             ]
+
+        The map from parameter, or parameter name, to constant
+        tensors.
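+
+        For example, the following sketch (assuming a module ``mod``
+        whose ``main`` function has a tensor parameter named ``x``)
+        binds ``x`` to a constant:
+
+        .. code-block:: python
+
+            weights = np.zeros([16], "float32")
+            mod = relax.transform.BindParams("main", {"x": weights})(mod)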
 
     Returns
     -------
diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc
index c444a84f44e0..27931b601760 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -25,6 +25,7 @@
 #include <tvm/relax/type.h>
 #include <tvm/tir/op.h>
 
+#include <tuple>
 #include <utility>
 
 namespace tvm {
@@ -81,45 +82,88 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant,
   }
 }
 
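+/*!
+ * \brief Normalize a user-provided parameter map into typed binding maps
+ *
+ * Resolves string keys to the uniquely-named function parameters they
+ * refer to, coerces NDArray values into relax::Constant, and infers
+ * the TIR variable bindings implied by the relax bindings.
+ *
+ * \param func The function whose parameters are to be bound
+ * \param untyped_params The map from parameter (relax::Var or string name) to value
+ *
+ * \return A tuple of the relax variable bindings and the inferred TIR bindings
+ */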
+std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
+    const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) {
+  ICHECK(func.defined());
+  ICHECK(untyped_params.defined());
+
+  // Map from string to the variable(s) with that name.
+  std::unordered_map<std::string, Array<relax::Var>> string_lookup;
+  std::unordered_set<const relax::VarNode*> var_set;
+  for (const auto& param : func->params) {
+    string_lookup[param->name_hint()].push_back(param);
+    var_set.insert(param.get());
+  }
+
+  Map<relax::Var, relax::Expr> relax_var_remap;
+
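+  // Resolve a user-provided key, which may be either a relax::Var or
+  // the string name of a parameter, to the parameter it refers to.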
+  auto normalize_key = [&](ObjectRef obj) -> relax::Var {
+    if (auto opt_str = obj.as<String>()) {
+      std::string str = opt_str.value();
+      auto it = string_lookup.find(str);
+      CHECK(it != string_lookup.end())
+          << "Function does not have parameter with name \"" << str << "\".  "
+          << "Function parameters are named "
+          << func->params.Map([](const auto& param) { return param->name_hint(); });
+      CHECK_EQ(it->second.size(), 1)
+          << "Function contains multiple parameters with name \"" << str << "\".  "
+          << "The Relax variables " << it->second << " are all named \"" << str << "\"";
+      auto var = it->second[0];
+      CHECK(!relax_var_remap.count(var))
+          << "Remap of variable " << var << " was defined multiple times";
+
+      return var;
+    } else if (auto opt_var = obj.as<relax::Var>()) {
+      auto var = opt_var.value();
+      CHECK(!relax_var_remap.count(var))
+          << "Remap of variable " << var << " was defined multiple times";
+      CHECK(var_set.count(var.get()))
+          << "Function does not use Relax variable " << var << " as a parameter.  "
+          << "Function parameters are " << func->params;
+      return var;
+    } else {
+      LOG(FATAL)
+          << "Expected bound parameter to be a relax::Var, "
+          << "or a string that uniquely identifies a relax::Var param within the function.  "
+          << "However, received object " << obj << " of type " << obj->GetTypeKey();
+    }
+  };
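+  // Coerce a user-provided value into a relax expression.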
+  auto normalize_value = [&](ObjectRef obj) -> relax::Expr {
+    if (auto opt = obj.as<relax::Expr>()) {
+      return opt.value();
+    } else if (auto opt = obj.as<runtime::NDArray>()) {
+      return Constant(opt.value());
+    } else {
+      LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
+                 << " into relax expression";
+    }
+  };
+
+  for (const auto& [key, value] : untyped_params) {
+    relax_var_remap.Set(normalize_key(key), normalize_value(value));
+  }
+
+  arith::Analyzer analyzer;
+  Map<tir::Var, PrimExpr> symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer);
+
+
+  return {relax_var_remap, symbolic_var_map};
+}
+
-/*!
- * \brief Bind params to function by using name
- * \param func Relax function
- * \param params params dict
- * \return Function
- */
+/*!
+ * \brief Bind params of a function to the given values
+ * \param func The relax function whose parameters are bound
+ * \param untyped_params The map from parameter (relax::Var or string name) to value
+ * \return The function with the specified parameters bound
+ */
-inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
-  std::unordered_map<std::string, Var> name_dict;
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
-  for (auto arg : func->params) {
-    const auto& name = arg->name_hint();
-    if (name_dict.count(name)) {
-      repeat_var.insert(name_dict[name]);
-    } else {
-      name_dict[name] = arg;
-    }
-  }
+Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& untyped_params) {
+  auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);
 
-  arith::Analyzer analyzer;
-  Map<Var, Expr> bind_dict;
-  Map<tir::Var, PrimExpr> symbolic_var_map;
-
-  for (auto& kv : params) {
-    if (name_dict.count(kv.first) == 0) {
-      continue;
-    }
-    const Var& arg = name_dict.at(kv.first);
-    if (repeat_var.count(arg)) {
-      LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
-    }
-    Expr const_expr = Constant(kv.second);
-    bind_dict.Set(arg, const_expr);
-    MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer);
-  }
   Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
-  Function ret = Downcast<Function>(bound_expr);
-  ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function."
-                        << "\n";
-  return ret;
+  return Downcast<Function>(bound_expr);
 }
 
 /*!
@@ -129,7 +173,7 @@ inline Function BindParamsByName(Function func, const Map<String, runtime::NDArr
- * \param param The param dict
+ * \param bind_params The map from parameter (relax::Var or string name) to value
  * \return The module after binding params.
  */
-IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
+IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef> bind_params) {
   IRModuleNode* new_module = m.CopyOnWrite();
   Map<GlobalVar, BaseFunc> functions = m->functions;
   for (const auto& func_pr : functions) {
@@ -138,13 +182,13 @@ IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> p
         // Use global_symbol if it's external linkage
         Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
         if (gsymbol.defined() && gsymbol.value() == func_name) {
-          Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
+          Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
           new_module->Update(func_pr.first, f_after_bind);
         }
       } else {
         // Use global var's name_hint if it's internal linkage
         if (func_pr.first->name_hint == func_name) {
-          Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
+          Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
           new_module->Update(func_pr.first, f_after_bind);
         }
       }
@@ -153,9 +197,11 @@ IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> p
   return GetRef<IRModule>(new_module);
 }
 
+TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams);
+
 namespace transform {
 
-Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
+Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
   return CreateModulePass(pass_func, 0, "BindParams", {});
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index ccb72805e371..f8235def240b 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -144,6 +144,62 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
   return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
 }
 
+tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
+    const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
+  tvm::Map<tir::Var, PrimExpr> tir_var_remap;
+
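+  // If a dimension of the variable's shape is a single TIR variable,
+  // bind it to the corresponding dimension of the bound expression.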
+  auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape,
+                                              const PrimExpr& expr_shape) {
+    if (auto var = var_shape.as<tir::Var>()) {
+      tir_var_remap.Set(var.value(), expr_shape);
+    }
+  };
+
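+  // Collect the bindings implied by a ShapeStructInfo annotation,
+  // e.g. a parameter annotated as R.Shape(["N"]).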
+  auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) {
+    auto var_shape = var.as<ShapeStructInfoNode>();
+    if (!var_shape) return;
+    if (!var_shape->values.defined()) return;
+
+    auto expr_shape = expr.as<ShapeStructInfoNode>();
+    CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
+                      << " to variable with struct type " << var;
+    if (!expr_shape->values.defined()) return;
+
+    auto var_shape_arr = var_shape->values.value();
+    auto expr_shape_arr = expr_shape->values.value();
+    CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
+        << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
+        << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
+    for (size_t i = 0; i < var_shape_arr.size(); i++) {
+      bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
+    }
+  };
+
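+  // Collect the bindings implied by a tensor's symbolic shape,
+  // e.g. a parameter annotated as R.Tensor(["N"], "float32").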
+  auto bind_from_tensor = [&bind_from_shape](const StructInfo& var, const StructInfo& expr) {
+    auto var_tensor = var.as<TensorStructInfoNode>();
+    if (!var_tensor) return;
+    if (!var_tensor->shape.defined()) return;
+
+    auto expr_tensor = expr.as<TensorStructInfoNode>();
+    CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
+                       << " to variable with struct type " << var;
+    if (!expr_tensor->shape.defined()) return;
+
+    bind_from_shape(GetStructInfo(var_tensor->shape.value()),
+                    GetStructInfo(expr_tensor->shape.value()));
+  };
+
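+  // Each helper is a no-op unless the StructInfo is of the matching
+  // kind, so both may be applied to every binding.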
+  for (const auto& [relax_var, relax_expr] : relax_var_remap) {
+    auto var_sinfo = GetStructInfo(relax_var);
+    auto expr_sinfo = GetStructInfo(relax_expr);
+
+    bind_from_tensor(var_sinfo, expr_sinfo);
+    bind_from_shape(var_sinfo, expr_sinfo);
+  }
+
+  return tir_var_remap;
+}
+
 bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
                       bool permit_unknown_dtype) {
   const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py
new file mode 100644
index 000000000000..a92e4fe8e510
--- /dev/null
+++ b/tests/python/relax/test_bind_params.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax, tir
+from tvm.script import relax as R
+
+import numpy as np
+import pytest
+
+param_specification = tvm.testing.parameter("by_string", "by_var")
+param_shape = tvm.testing.parameter("static_shape", "dynamic_shape", "ndim", "arbitrary")
+tensor_param_dtype = tvm.testing.parameter("float32", None)
+
+
+def test_bind_tensor_param(param_specification, param_shape, tensor_param_dtype):
+    if param_shape == "static_shape":
+        shape = [16]
+        ndim = -1
+    elif param_shape == "dynamic_shape":
+        shape = [tir.Var("N", "int64")]
+        ndim = -1
+    elif param_shape == "ndim":
+        shape = None
+        ndim = 1
+    elif param_shape == "arbitrary":
+        shape = None
+        ndim = -1
+    else:
+        raise ValueError(f"Unknown param_shape: {param_shape}")
+
+    @R.function
+    def before(A: R.Tensor(shape, ndim=ndim, dtype=tensor_param_dtype)):
+        R.func_attr({"global_symbol": "main"})
+        B: R.Tensor(shape=shape, ndim=ndim, dtype=tensor_param_dtype) = A
+        out = R.add(B, B)
+        return out
+
+    np_data = np.arange(16).astype("float32")
+    inlined_relax_const = relax.const(np_data)
+
+    @R.function
+    def expected() -> R.Tensor([16], "float32"):
+        R.func_attr({"global_symbol": "main"})
+        B = inlined_relax_const
+        out = R.add(B, B)
+        return out
+
+    if param_specification == "by_string":
+        var = "A"
+    elif param_specification == "by_var":
+        var = before.params[0]
+    else:
+        raise ValueError("Unknown param_specification: {param_specification}")
+
+    after = before.bind_params({var: np.arange(16).astype("float32")})
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_shape_param(param_shape):
+    if param_shape == "static_shape":
+        shape = [16]
+        ndim = -1
+    elif param_shape == "dynamic_shape":
+        shape = [tir.Var("N", "int64")]
+        ndim = -1
+    elif param_shape == "ndim":
+        shape = None
+        ndim = 1
+    elif param_shape == "arbitrary":
+        shape = None
+        ndim = -1
+    else:
+        raise ValueError(f"Unknown param_shape: {param_shape}")
+
+    @R.function
+    def before(A: R.Shape(shape, ndim=ndim)):
+        R.func_attr({"global_symbol": "main"})
+        B: R.Shape(shape, ndim=ndim) = A
+        return B
+
+    @R.function
+    def expected() -> R.Shape([16]):
+        R.func_attr({"global_symbol": "main"})
+        B = R.ShapeExpr([16])
+        return B
+
+    after = before.bind_params({"A": relax.ShapeExpr([16])})
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+prim_value_dtype = tvm.testing.parameter("int64", "int32", "float32")
+
+
+@pytest.mark.xfail(reason="Depends on relax.PrimValue holding a tir.PrimExpr, PR#15577")
+def test_bind_prim_value(prim_value_dtype):
+    @R.function
+    def before(A: R.Prim(value="N", dtype=prim_value_dtype)):
+        R.func_attr({"global_symbol": "main"})
+        B: R.Prim(value="N", dtype=prim_value_dtype) = A
+        return B
+
+    @R.function
+    def expected() -> R.Prim(value=16, dtype=prim_value_dtype):
+        R.func_attr({"global_symbol": "main"})
+        B = R.PrimValue(value=16, dtype=prim_value_dtype)
+        return B
+
+    after = before.bind_params({"A": relax.PrimValue(tir.const(16, prim_value_dtype))})
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_error_on_unknown_var():
+    @R.function
+    def before(A: R.Tensor([16], dtype="float32")):
+        R.func_attr({"global_symbol": "main"})
+        return A
+
+    unknown_var = relax.Var("unknown_var")
+
+    with pytest.raises(tvm.TVMError):
+        before.bind_params({unknown_var: np.arange(16).astype("float32")})
+
+
+def test_error_on_unknown_var_name():
+    @R.function
+    def before(A: R.Tensor([16], dtype="float32")):
+        R.func_attr({"global_symbol": "main"})
+        return A
+
+    with pytest.raises(tvm.TVMError):
+        before.bind_params({"unknown_var_name": np.arange(16).astype("float32")})
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py
index 54cd1b243dbf..c752fa5e1015 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -126,10 +126,14 @@ def test_verify_e2e_translation_gpu(layout, batch_size, image_shape):
 def verify_extracted_tasks(target_str, layout, batch_size, image_shape, module_equality):
     target = Target(target_str)
     relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape)
+    # Parameters can be bound either as part of the `from_relay`
+    # conversion, or as part of the `extract_tasks` method.  However,
+    # they shouldn't be bound in both locations, because
+    # `relax.BindParams` validates that an unbound parameter with the
+    # specified name exists.
     relax_mod = relay_translator.from_relay(
         relay_mod["main"],
         target,
-        params,
         pass_config={
             "relay.backend.use_meta_schedule": True,
             "relay.FuseOps.max_depth": 1,  # Disable relay fusion
diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py
index 8e760b6fd70f..9e212693f969 100644
--- a/tests/python/relax/test_transform_bind_params.py
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -123,5 +123,57 @@ def main(
     )
 
 
+param_specification = tvm.testing.parameter("by_string", "by_var")
+
+
+def test_bind_params_by_var_obj(param_specification):
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            return A
+
+    np_data = np.arange(16).astype("float32")
+    inlined_relax_const = relax.const(np_data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main():
+            return inlined_relax_const
+
+    if param_specification == "by_string":
+        var = "A"
+    elif param_specification == "by_var":
+        var = Before["main"].params[0]
+    else:
+        raise ValueError("Unknown param_specification: {param_specification}")
+
+    After = relax.transform.BindParams("main", {var: np_data})(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_bind_params_by_var_name():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            return A
+
+    np_data = np.arange(16).astype("float32")
+    inlined_relax_const = relax.const(np_data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main():
+            return inlined_relax_const
+
+    After = relax.transform.BindParams("main", {"A": np_data})(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index b8ad5c4487d3..c2a3bd50922b 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -378,8 +378,7 @@ def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16), dtype="flo
     before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np})
     assert relax.analysis.well_formed(before)
 
-    c2_np = np.multiply(np.add(c0_np, c0_np), c1_np)
-    expected = gen_mod(Module, "expected", {"c2": c2_np})
+    expected = gen_mod(Module, "expected", {})
 
     after = relax.transform.FoldConstant()(before)
     tvm.ir.assert_structural_equal(after, expected)