Commit 5cd6447
Update on "[ET-VK] Adding batch processing in x axis to conv2d dw sha…
Browse files Browse the repository at this point in the history
…der by caching input texel for reuse."

This diff adds batch processing in the x axis to the conv2d dw (depthwise) shader by reusing the input texels that overlap between consecutive tiles. The changes modify the GLSL code for the conv2d dw output tile, add a new parameter to the YAML file, and update Convolution.cpp to use the new parameter.

Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/)

[ghstack-poisoned]
trivedivivek committed Jan 7, 2025
2 parents b6d7a76 + 7260da1 commit 5cd6447
Showing 18 changed files with 133 additions and 57 deletions.
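
To illustrate the reuse pattern described in the commit message, here is a minimal C++ sketch (not the shader itself: the real change is in GLSL, and fetch_input, row, and acc are hypothetical stand-ins, while TILE_SIZE and BATCH_SIZE_X mirror the shader's parameters). For BATCH_SIZE_X consecutive output positions along x, adjacent kernel windows overlap in all but one column, so one widened row of TILE_SIZE + BATCH_SIZE_X - 1 input texels can be fetched once and shared, instead of re-fetching TILE_SIZE texels per output.

// Minimal C++ sketch of the texel-reuse idea (hypothetical names; the real
// implementation is the GLSL shader changed in this commit).
#include <array>
#include <cstdio>

constexpr int TILE_SIZE = 3;     // kernel extent along x, as in the shader
constexpr int BATCH_SIZE_X = 2;  // outputs produced per invocation along x

// Stand-in for a texelFetch from the input texture.
float fetch_input(int x) {
  return static_cast<float>(x);
}

int main() {
  // One widened row covers every texel needed by all BATCH_SIZE_X windows:
  // TILE_SIZE + BATCH_SIZE_X - 1 fetches instead of TILE_SIZE * BATCH_SIZE_X.
  std::array<float, TILE_SIZE + BATCH_SIZE_X - 1> row;
  for (int i = 0; i < TILE_SIZE + BATCH_SIZE_X - 1; ++i) {
    row[i] = fetch_input(i);  // each input texel is fetched exactly once
  }
  // Each batched output slides its window across the cached row.
  for (int s = 0; s < BATCH_SIZE_X; ++s) {
    float acc = 0.0f;
    for (int j = 0; j < TILE_SIZE; ++j) {
      acc += row[s + j];  // the window for output s starts at offset s
    }
    std::printf("output %d: %g\n", s, acc);
  }
  return 0;
}

With TILE_SIZE = 3 and BATCH_SIZE_X = 2 this is 4 fetches instead of 6, and the saving grows with BATCH_SIZE_X.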
4 changes: 2 additions & 2 deletions .lintrunner.toml
@@ -302,8 +302,8 @@ include_patterns = [
'profiler/**/*.py',
'runtime/**/*.py',
'scripts/**/*.py',
- # 'test/**/*.py',
- # 'util/**/*.py',
+ 'test/**/*.py',
+ 'util/**/*.py',
'*.py',
]
exclude_patterns = [
12 changes: 11 additions & 1 deletion .mypy.ini
@@ -21,10 +21,14 @@ files =
profiler,
runtime,
scripts,
+ test,
util

mypy_path = executorch

[mypy-executorch.backends.*]
follow_untyped_imports = True

[mypy-executorch.codegen.*]
follow_untyped_imports = True

@@ -46,6 +50,12 @@ follow_untyped_imports = True
[mypy-executorch.runtime.*]
follow_untyped_imports = True

+ [mypy-executorch.test.*]
+ follow_untyped_imports = True

+ [mypy-functorch.*]
+ follow_untyped_imports = True

[mypy-requests.*]
follow_untyped_imports = True

@@ -80,4 +90,4 @@ ignore_missing_imports = True
ignore_missing_imports = True

[mypy-zstd]
- ignore_missing_imports = True
+ ignore_missing_imports = True
10 changes: 1 addition & 9 deletions backends/cadence/fusion_g3/operators/op_add.cpp
@@ -10,6 +10,7 @@

#include <xa_nnlib_kernels_api.h>

+ #include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
@@ -28,15 +29,6 @@ namespace impl {
namespace G3 {
namespace native {

- #define XT_KERNEL_CHECK(ctx, out, kernel, ...) \
- const auto ret = kernel(__VA_ARGS__); \
- ET_KERNEL_CHECK_MSG( \
- ctx, \
- ret == 0, \
- InvalidArgument, \
- out, \
- "Failed to run kernel: " #kernel "(" #__VA_ARGS__ ")");

Tensor& add_out(
KernelRuntimeContext& ctx,
const Tensor& a,
13 changes: 13 additions & 0 deletions backends/cadence/fusion_g3/operators/targets.bzl
@@ -27,6 +27,7 @@ def define_operator(name: str, deps: list[str] | None = None) -> None:
deps = deps + common_deps,
exported_deps = [
":operators_header",
":xt_macros",
],
)

@@ -61,5 +62,17 @@ def define_common_targets():
],
)

+ runtime.cxx_library(
+ name = "xt_macros",
+ exported_headers = ["xt_macros.h"],
+ visibility = [
+ "//executorch/backends/cadence/...",
+ ],
+ exported_deps = [
+ "//executorch/runtime/core/exec_aten:lib",
+ "//executorch/runtime/kernel:kernel_runtime_context",
+ ],
+ )

for op in OPERATORS:
define_operator(op)
20 changes: 20 additions & 0 deletions backends/cadence/fusion_g3/operators/xt_macros.h
@@ -0,0 +1,20 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

#define XT_KERNEL_CHECK(ctx, out, kernel, ...) \
const auto ret = kernel(__VA_ARGS__); \
ET_KERNEL_CHECK_MSG( \
ctx, \
ret == 0, \
InvalidArgument, \
out, \
"Failed to run kernel: " #kernel "(" #__VA_ARGS__ ")");
@@ -100,7 +100,7 @@ void main() {
}

// accumulate dot product in 1st sum only until tile size
- if (i < int(TILE_SIZE)) {
+ if (i < TILE_SIZE) {
for (int j = 0; j < TILE_SIZE; j++, kx++) {
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
for (int s = 0; s < BATCH_SIZE_X; s++) {
2 changes: 1 addition & 1 deletion docs/TARGETS
@@ -1,7 +1,7 @@
load("@fbcode_macros//build_defs:native_rules.bzl", "buck_filegroup", "buck_sh_test")
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

oncall("pytorch_r2p")
oncall("executorch")

python_binary(
name = "sphinx",
16 changes: 16 additions & 0 deletions extension/flat_tensor/serialize/TARGETS
@@ -0,0 +1,16 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

runtime.python_library(
name = "schema",
srcs = [
"flat_tensor_schema.py",
],
visibility = [
"//executorch/...",
],
)
@@ -13,7 +13,7 @@ table TensorMetadata {
scalar_type: executorch_flatbuffer.ScalarType;

// Size of each dimension.
- dim_sizes: [int32];
+ sizes: [int32];

// Specifies in what order the dimensions are laid out in memory (from outer
// to inner).
@@ -17,7 +17,7 @@
class TensorMetadata:
fully_qualified_name: str
scalar_type: ScalarType
- dim_sizes: List[int]
+ sizes: List[int]
dim_order: List[bytes]

segment_index: int
File renamed without changes.
36 changes: 36 additions & 0 deletions extension/flat_tensor/serialize/targets.bzl
@@ -0,0 +1,36 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
runtime.genrule(
name = "gen_schema",
srcs = [
"flat_tensor.fbs",
"scalar_type.fbs",
],
outs = {
"schema_generated.h": ["flat_tensor_generated.h"],
"scalar_type_generated.h": ["scalar_type_generated.h"]
},
cmd = " ".join([
"$(exe {})".format(runtime.external_dep_location("flatc")),
"--cpp",
"--cpp-std c++11",
"--scoped-enums",
"-o ${OUT}",
"${SRCS}",
]),
default_outs = ["."],
)

runtime.cxx_library(
name = "generated_headers",
srcs = [],
visibility = [
"//executorch/...",
],
exported_headers = {
"schema_generated.h": ":gen_schema[schema_generated.h]",
"scalar_type_generated.h": ":gen_schema[scalar_type_generated.h]",
},
exported_external_deps = ["flatbuffers-api"],
)
20 changes: 6 additions & 14 deletions test/end2end/exported_module.py
@@ -126,9 +126,7 @@ def return_wrapper():
trace_inputs_method = "get_upper_bound_inputs"
get_trace_inputs = get_inputs_adapter(
(
- # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
- # `Union[Module, Tensor]`.
- getattr(eager_module, trace_inputs_method)
+ getattr(eager_module, trace_inputs_method) # type: ignore[arg-type]
if hasattr(eager_module, trace_inputs_method)
else eager_module.get_random_inputs
),
@@ -144,18 +142,14 @@ def return_wrapper():
if hasattr(eager_module, "get_dynamic_shapes"):
assert capture_config is not None
assert capture_config.enable_aot is True
- # pyre-fixme[29]: `Union[nn.modules.module.Module,
- # torch._tensor.Tensor]` is not a function.
- trace_dynamic_shapes = eager_module.get_dynamic_shapes()
+ trace_dynamic_shapes = eager_module.get_dynamic_shapes() # type: ignore[operator]
method_name_to_dynamic_shapes = {}
for method in methods:
method_name_to_dynamic_shapes[method] = trace_dynamic_shapes

memory_planning_pass = MemoryPlanningPass()
if hasattr(eager_module, "get_memory_planning_pass"):
- # pyre-fixme[29]: `Union[nn.modules.module.Module,
- # torch._tensor.Tensor]` is not a function.
- memory_planning_pass = eager_module.get_memory_planning_pass()
+ memory_planning_pass = eager_module.get_memory_planning_pass() # type: ignore[operator]

class WrapperModule(nn.Module):
def __init__(self, method):
@@ -172,7 +166,7 @@ def __init__(self, method):
assert method_name == "forward"
ep = _export(
eager_module,
- method_input,
+ method_input, # type: ignore[arg-type]
dynamic_shapes=(
method_name_to_dynamic_shapes[method_name]
if method_name_to_dynamic_shapes
@@ -184,7 +178,7 @@ def __init__(self, method):
else:
exported_methods[method_name] = export(
eager_module,
- method_input,
+ method_input, # type: ignore[arg-type]
dynamic_shapes=(
method_name_to_dynamic_shapes[method_name]
if method_name_to_dynamic_shapes
@@ -220,9 +214,7 @@ def __init__(self, method):

# Get a function that creates random inputs appropriate for testing.
get_random_inputs_fn = get_inputs_adapter(
- # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
- # `Union[Module, Tensor]`.
- eager_module.get_random_inputs,
+ eager_module.get_random_inputs, # type: ignore[arg-type]
# all exported methods must have the same signature so just pick the first one.
methods[0],
)
6 changes: 1 addition & 5 deletions test/end2end/test_end2end.py
@@ -52,9 +52,7 @@
kernel_mode = None # either aten mode or lean mode
try:
from executorch.extension.pybindings.portable_lib import (
- _load_bundled_program_from_buffer,
_load_for_executorch_from_buffer,
- _load_for_executorch_from_bundled_program,
)

kernel_mode = "lean"
@@ -63,10 +61,8 @@
pass

try:
- from executorch.extension.pybindings.aten_lib import (
- _load_bundled_program_from_buffer,
+ from executorch.extension.pybindings.aten_lib import ( # type: ignore[import-not-found]
_load_for_executorch_from_buffer,
- _load_for_executorch_from_bundled_program,
)

assert kernel_mode is None
6 changes: 2 additions & 4 deletions test/models/export_delegated_program.py
@@ -118,9 +118,7 @@ def export_module_to_program(
eager_module = module_class().eval()
inputs = ()
if hasattr(eager_module, "get_random_inputs"):
- # pyre-fixme[29]: `Union[nn.modules.module.Module, torch._tensor.Tensor]` is
- # not a function.
- inputs = eager_module.get_random_inputs()
+ inputs = eager_module.get_random_inputs() # type: ignore[operator]

class WrapperModule(torch.nn.Module):
def __init__(self, fn):
@@ -153,7 +151,7 @@ def forward(self, *args, **kwargs):
).to_executorch(config=et_config)
else:
edge: exir.EdgeProgramManager = to_edge(exported_program)
- lowered_module = to_backend(
+ lowered_module = to_backend( # type: ignore[call-arg]
backend_id, edge.exported_program(), compile_specs=[]
)

4 changes: 3 additions & 1 deletion test/models/generate_linear_out_bundled_program.py
@@ -27,7 +27,9 @@
from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass
from executorch.exir.print_program import pretty_print

- from executorch.test.models.linear_model import LinearModel
+ from executorch.test.models.linear_model import ( # type: ignore[import-not-found]
+ LinearModel,
+ )
from torch.export import export


21 changes: 10 additions & 11 deletions util/activation_memory_profiler.py
@@ -9,7 +9,7 @@
import json
import typing
from dataclasses import dataclass, field
- from typing import List
+ from typing import Any, Dict, List, Optional

import executorch.exir.memory as memory
import torch
@@ -52,7 +52,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
allocations at that timestep.
"""
nodes = graph.nodes
- memory_timeline = [None] * len(nodes)
+ memory_timeline: List[Optional[MemoryTimeline]] = [None for _ in range(len(nodes))]
for _, node in enumerate(nodes):
if node.op == "output":
continue
@@ -72,11 +72,11 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
stack_trace = node.meta.get("stack_trace")
fqn = _get_module_hierarchy(node)
for j in range(start, end + 1):
- if memory_timeline[j] is None:
- # pyre-ignore
- memory_timeline[j] = MemoryTimeline()
- # pyre-ignore
- memory_timeline[j].allocations.append(
+ memory_timeline_j = memory_timeline[j]
+ if memory_timeline_j is None:
+ memory_timeline_j = MemoryTimeline()
+ assert memory_timeline_j
+ memory_timeline_j.allocations.append(
Allocation(
node.name,
node.target,
@@ -87,8 +87,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
stack_trace,
)
)
- # pyre-ignore
- return memory_timeline
+ return memory_timeline # type: ignore[return-value]


def _validate_memory_planning_is_done(exported_program: ExportedProgram):
@@ -129,7 +128,7 @@ def generate_memory_trace(

memory_timeline = create_tensor_allocation_info(exported_program.graph)
root = {}
- trace_events = []
+ trace_events: List[Dict[str, Any]] = []
root["traceEvents"] = trace_events

tid = 0
@@ -138,7 +137,7 @@ def generate_memory_trace(
if memory_timeline_event is None:
continue
for allocation in memory_timeline_event.allocations:
- e = {}
+ e: Dict[str, Any] = {}
e["name"] = allocation.name
e["cat"] = "memory_allocation"
e["ph"] = "X"