Commit f092200

update runtime extension 1.12 document (#1012)
* fix the package name of Task Example
* change the API namespace to ipex
1 parent dcabe00 commit f092200

File tree

1 file changed: docs/tutorials/features/runtime_extension.md (+24 −22 lines)
@@ -3,9 +3,9 @@ Runtime Extension
 
 Intel® Extension for PyTorch\* Runtime Extension provides a couple of PyTorch frontend APIs for users to get finer-grained control of the thread runtime. It provides:
 
-1. Multi-stream inference via the Python frontend module `intel_extension_for_pytorch.cpu.runtime.MultiStreamModule`.
-2. Spawn asynchronous tasks via the Python frontend module `intel_extension_for_pytorch.cpu.runtime.Task`.
-3. Program core bindings for OpenMP threads via the Python frontend `intel_extension_for_pytorch.cpu.runtime.pin`.
+1. Multi-stream inference via the Python frontend module `ipex.cpu.runtime.MultiStreamModule`.
+2. Spawn asynchronous tasks via the Python frontend module `ipex.cpu.runtime.Task`.
+3. Program core bindings for OpenMP threads via the Python frontend `ipex.cpu.runtime.pin`.
 
 **note**: Intel® Extension for PyTorch\* Runtime extension is in the **experimental** stage. The API is subject to change. More detailed descriptions are available at [API Documentation page](../api_doc.rst).
@@ -27,6 +27,9 @@ If the inputs' batchsize is larger than and divisible by ``num_streams``, the ba
 
 Let's create some ExampleNets that will be used by further examples:
 ```
+import torch
+import intel_extension_for_pytorch as ipex
+
 class ExampleNet1(torch.nn.Module):
     def __init__(self):
         super(ExampleNet1, self).__init__()
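The hunk header above quotes the doc's batch-splitting rule: when the inputs' batch size is larger than and divisible by `num_streams`, the batch is split evenly across streams. As a rough illustration in plain Python (not ipex code; the spread-the-remainder policy for the non-divisible case is an assumption for the sketch, not ipex's documented behavior):

```python
def split_batch_sizes(batch_size: int, num_streams: int) -> list:
    """Sketch of a per-stream mini-batch split: even when divisible,
    otherwise as even as possible with the remainder spread over the
    first few streams; streams that would get 0 samples are dropped."""
    base, rem = divmod(batch_size, num_streams)
    sizes = [base + 1] * rem + [base] * (num_streams - rem)
    return [s for s in sizes if s > 0]

print(split_batch_sizes(64, 4))  # divisible: [16, 16, 16, 16]
print(split_batch_sizes(10, 4))  # not divisible: [3, 3, 2, 2]
```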
@@ -70,8 +73,8 @@ with torch.no_grad():
 Here is the example of a model with single tensor input/output. We create a `CPUPool` with all the cores available on numa node 0, and create a `MultiStreamModule` with a stream number of 2 to do inference.
 ```
 # Convert the model into multi_Stream_model
-cpu_pool = intel_extension_for_pytorch.cpu.runtime.CPUPool(node_id=0)
-multi_Stream_model = intel_extension_for_pytorch.cpu.runtime.MultiStreamModule(traced_model1, num_streams=2, cpu_pool=cpu_pool)
+cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
+multi_Stream_model = ipex.cpu.runtime.MultiStreamModule(traced_model1, num_streams=2, cpu_pool=cpu_pool)
 
 with torch.no_grad():
     y = multi_Stream_model(x)
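Conceptually, a multi-stream module splits the batch, runs each chunk on its own stream, and concatenates the per-stream outputs in order. A minimal stand-alone sketch of that idea, using threads and plain lists in place of ipex streams and tensors (names here are illustrative, not the ipex implementation):

```python
from concurrent.futures import ThreadPoolExecutor

def multi_stream_run(model, batch, num_streams):
    """Conceptual sketch: split the batch into num_streams chunks, run
    each chunk concurrently, then concatenate results in stream order."""
    chunk = len(batch) // num_streams
    chunks = [batch[i * chunk:(i + 1) * chunk] for i in range(num_streams - 1)]
    chunks.append(batch[(num_streams - 1) * chunk:])  # last chunk takes any remainder
    with ThreadPoolExecutor(max_workers=num_streams) as pool:
        results = list(pool.map(model, chunks))      # one "stream" per chunk
    return [y for part in results for y in part]     # concat along the batch dim

double = lambda xs: [2 * x for x in xs]  # stand-in for traced_model1
print(multi_stream_run(double, [1, 2, 3, 4, 5, 6], num_streams=2))  # [2, 4, 6, 8, 10, 12]
```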
@@ -81,7 +84,7 @@ with torch.no_grad():
 When creating a `MultiStreamModule`, we have default settings for `num_streams` ("AUTO") and `cpu_pool` (with all the cores available on numa node 0). For a `num_streams` of "AUTO", there are limitations when used with the int8 datatype, as mentioned in the performance recipes section below.
 ```
 # Convert the model into multi_Stream_model
-multi_Stream_model = intel_extension_for_pytorch.cpu.runtime.MultiStreamModule(traced_model1)
+multi_Stream_model = ipex.cpu.runtime.MultiStreamModule(traced_model1)
 
 with torch.no_grad():
     y = multi_Stream_model(x)
@@ -91,17 +94,17 @@ with torch.no_grad():
 For modules such as ExampleNet2 with structured input/output tensors, the user needs to create a `MultiStreamModuleHint` as the input hint and output hint. `MultiStreamModuleHint` tells `MultiStreamModule` how to auto-split the input into streams and concat the output from each stream.
 ```
 # Convert the model into multi_Stream_model
-cpu_pool = intel_extension_for_pytorch.cpu.runtime.CPUPool(node_id=0)
+cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
 # Create the input hint object
-input_hint = intel_extension_for_pytorch.cpu.runtime.MultiStreamModuleHint(0, 0)
+input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, 0)
 # Create the output hint object
 # When a Python module has multiple output tensors, they are auto-packed into a tuple, so we pass a tuple (0, 0) to create the output_hint
-output_hint = intel_extension_for_pytorch.cpu.runtime.MultiStreamModuleHint((0, 0))
-multi_Stream_model = intel_extension_for_pytorch.cpu.runtime.MultiStreamModule(traced_model2,
-                                                                               num_streams=2,
-                                                                               cpu_pool=cpu_pool,
-                                                                               input_split_hint=input_hint,
-                                                                               output_concat_hint=output_hint)
+output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0))
+multi_Stream_model = ipex.cpu.runtime.MultiStreamModule(traced_model2,
+                                                        num_streams=2,
+                                                        cpu_pool=cpu_pool,
+                                                        input_split_hint=input_hint,
+                                                        output_concat_hint=output_hint)
 
 with torch.no_grad():
     y = multi_Stream_model(x, x2)
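The idea an input hint like `(0, 0)` expresses, per the hunk above, is "split each positional input along dimension 0". A toy sketch of that splitting step with nested lists (this is an illustration of the hint semantics, not ipex's implementation, and it only handles batch-dimension splits):

```python
def split_inputs(inputs, split_hint, num_streams):
    """Sketch: for hint (0, 0), split each positional input along dim 0.
    Returns one argument tuple per stream."""
    per_stream = []
    for s in range(num_streams):
        args = []
        for inp, dim in zip(inputs, split_hint):
            assert dim == 0, "sketch only handles batch-dim (0) splits"
            n = len(inp) // num_streams
            args.append(inp[s * n:(s + 1) * n])
        per_stream.append(tuple(args))
    return per_stream

# Two inputs (x, x2), hint (0, 0), two streams:
streams = split_inputs(([1, 2, 3, 4], [5, 6, 7, 8]), split_hint=(0, 0), num_streams=2)
print(streams)  # [([1, 2], [5, 6]), ([3, 4], [7, 8])]
```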
@@ -133,12 +136,11 @@ Here are some performance recipes that we recommend for better multi-stream performance
 Here is an example of using asynchronous tasks. With the support of the runtime API, you can run 2 modules simultaneously, each on its corresponding CPU pool.
 
 ```
-# Create the cpu pool and numa aware memory allocator
-cpu_pool1 = ipex.runtime.CPUPool([0, 1, 2, 3])
-cpu_pool2 = ipex.runtime.CPUPool([4, 5, 6, 7])
+cpu_pool1 = ipex.cpu.runtime.CPUPool([0, 1, 2, 3])
+cpu_pool2 = ipex.cpu.runtime.CPUPool([4, 5, 6, 7])
 
-task1 = ipex.runtime.Task(traced_model1, cpu_pool1)
-task2 = ipex.runtime.Task(traced_model1, cpu_pool2)
+task1 = ipex.cpu.runtime.Task(traced_model1, cpu_pool1)
+task2 = ipex.cpu.runtime.Task(traced_model1, cpu_pool2)
 
 y1_future = task1(x)
 y2_future = task2(x)
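The submit-then-get pattern in the hunk above (calls return futures immediately; `.get()` blocks for the result) is the same shape as Python's standard `concurrent.futures`, which can stand in for a runnable sketch. Here `.result()` plays the role of the `.get()` in the diff, and the lambda stands in for `traced_model1`:

```python
from concurrent.futures import ThreadPoolExecutor

# Stand-ins for two Task objects, each backed by its own worker pool
pool1 = ThreadPoolExecutor(max_workers=1)
pool2 = ThreadPoolExecutor(max_workers=1)
model = lambda x: x * x  # stand-in for traced_model1

x = 7
y1_future = pool1.submit(model, x)  # both submissions return immediately
y2_future = pool2.submit(model, x)  # and run concurrently
y1 = y1_future.result()             # blocks, analogous to y1_future.get()
y2 = y2_future.result()
print(y1, y2)  # 49 49
```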
@@ -149,11 +151,11 @@ y2 = y2_future.get()
 
 ### Example of configuring core binding
 
-Runtime Extension provides API of `intel_extension_for_pytorch.cpu.runtime.pin` to a CPU Pool for binding physical cores. We can use it without the async task feature. Here is the example to use `intel_extension_for_pytorch.cpu.runtime.pin` in the `with` context.
+Runtime Extension provides the `ipex.cpu.runtime.pin` API for binding physical cores to a CPU pool. It can be used without the async task feature. Here is an example of using `ipex.cpu.runtime.pin` in a `with` context.
 
 ```
-cpu_pool = intel_extension_for_pytorch.cpu.runtime.CPUPool(node_id=0)
-with intel_extension_for_pytorch.cpu.runtime.pin(cpu_pool):
+cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
+with ipex.cpu.runtime.pin(cpu_pool):
     y_runtime = traced_model1(x)
 ```
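The `with` semantics above (binding applies inside the block, previous state restored on exit) follow the standard context-manager pattern, which can be sketched with `contextlib`. This is a hypothetical stand-in that only tracks a binding in a dict, not ipex's actual core pinning:

```python
from contextlib import contextmanager

_current_pool = {"cores": None}  # stand-in for the thread runtime's binding state

@contextmanager
def pin(cores):
    """Sketch of `with pin(...)`: record the previous binding, apply the
    new one for the block, then restore on exit (even if the block raises)."""
    previous = _current_pool["cores"]
    _current_pool["cores"] = list(cores)
    try:
        yield
    finally:
        _current_pool["cores"] = previous

with pin([0, 1, 2, 3]):
    inside = list(_current_pool["cores"])  # binding active inside the block
after = _current_pool["cores"]             # restored to the prior state (None)
print(inside, after)  # [0, 1, 2, 3] None
```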
