This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

task 801 - launch physical meshes after compilation #938

Open · wants to merge 2 commits into base: main
7 changes: 6 additions & 1 deletion alpa/create_state_parallel.py
@@ -7,7 +7,7 @@
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
import numpy as np

-from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup
+from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup, VirtualMeshGroup
from alpa.global_env import global_config
from alpa.mesh_executable import (NormalMeshDriverExecutable,
                                  GradAccMeshDriverExecutable)
@@ -30,12 +30,14 @@ class CreateStateExecutable(PipeshardDriverExecutable):

    def __init__(self,
                 mesh_group: PhysicalDeviceMeshGroup,
+                #virtual_mesh_group: VirtualMeshGroup,
[Review comment · Collaborator] remove unused lines. Also applies for some code below
                 pipeshard_config: PipeshardConfig,
                 target_placement_specs: Sequence[PlacementSpec],
                 in_tree: PyTreeDef,
                 out_tree: Optional[PyTreeDef] = None,
                 static_argnums: Optional[Sequence[int]] = None):
        super().__init__(mesh_group=mesh_group,
+                        #virtual_mesh_group= virtual_mesh_group,
                         pipeshard_config=pipeshard_config,
                         num_batch=1,
                         layer_option=None,
@@ -134,13 +136,16 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,
                                           sliced_eqns)

    # Compile a pipeshard executable with predefined output shardings
+   #pipeshard_config, _ , virtual_mesh_group = compile_pipeshard_executable_internal(
    pipeshard_config = compile_pipeshard_executable_internal(
        new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals),
        executable.mesh_group.parent, 1, "inference",
        AutoShardingOption(enable_auto_sharding=False),
        UniformStageOption(), name, None, output_shardings, None, None)

    return CreateStateExecutable(mesh_group=executable.mesh_group,
+                                #virtual_mesh_group= pipeshard_config.virtual_meshes,
+                                #virtual_mesh_group=virtual_mesh_group,
                                 pipeshard_config=pipeshard_config,
                                 target_placement_specs=placement_specs,
                                 in_tree=in_tree,
156 changes: 152 additions & 4 deletions alpa/device_mesh.py
@@ -56,7 +56,9 @@
                        update_jax_platform, is_ray_node_resource,
                        try_import_ray_worker, create_placement_group,
                        get_bundle_idx, retrieve_placement_group, get_bundle2ip,
-                       check_server_port)
+                       check_server_port, compile_allgather)


ray_worker = try_import_ray_worker()
@@ -1951,7 +1953,7 @@ def get_physical_mesh(self, mesh_id: int = 0):
                                                  mesh_id=mesh_id)
        return self.launched_physical_mesh

-    def get_physical_mesh_group(self, sliced_virtual_meshes):
+    def get_physical_mesh_group(self, sliced_virtual_meshes, pipeshard_config):
        """Launch a physical mesh group (which will request resources from
        Ray)."""
        assert self.launched_physical_mesh_group is None, \
@@ -1972,20 +1974,23 @@ def launch_func(i):
            threads[i].join()

        self.launched_physical_mesh_group = (PhysicalDeviceMeshGroup(
-            physical_meshes, self))
+            physical_meshes, self, pipeshard_config))

        return self.launched_physical_mesh_group


class PhysicalDeviceMeshGroup:
    """A list of physical devices that forms a pipeline."""

    def __init__(self, meshes: Sequence[DistributedPhysicalDeviceMesh],
-                 parent: VirtualPhysicalMesh):
+                 parent: VirtualPhysicalMesh, pipeshard_config):
        self.meshes = list(meshes)
        self.parent = parent
        self.collective_groups: List[List[Any]] = [
            [None for _ in range(len(self))] for _ in range(len(self))
        ]
+        #task 801
+        self.instantiate(pipeshard_config)

    def __getitem__(self, index):
        return self.meshes[index]
@@ -2124,6 +2129,77 @@ def _instantiate_nccl_group(cg):
        else:
            cg.instantiate()

+    def instantiate(self, pipeshard_config):
+        from alpa.mesh_executable import UtilMeshWorkerExecutable
+
+        virtual_worker_to_rank_map = {}
+        virtual_to_pysical_map = {}
+        self.collective_groups = pipeshard_config.virtual_meshes.collective_groups
+        # task 801 - replacing virtual workers with ray workers
+        temp_mesh_grp = []
+        for mesh in self.meshes:
+            for worker in mesh.workers:
+                temp_mesh_grp.append(worker)
+        virtual_worker_to_rank_map = {
+            worker: r for r, worker in enumerate(temp_mesh_grp)
+        }
+        for cgp in self.collective_groups:
+            for cg in cgp:
+                if cg is not None:
+                    cg.mesh_workers = temp_mesh_grp
+                    cg.worker_to_rank_map = virtual_worker_to_rank_map
+                    for key, worker in cg.device_str_to_mesh_worker_map.items():
+                        if isinstance(worker, VirtualWorker):
+                            cg.device_str_to_mesh_worker_map[key] = (
+                                cg.mesh_workers[worker.index])
+
+        for virtual_worker, _ in pipeshard_config.instruction_lists.items():
+            virtual_to_pysical_map[virtual_worker.index] = virtual_worker
+
+        pipeshard_config.virtual_worker_to_rank_map = virtual_worker_to_rank_map
+        pipeshard_config.virtual_to_pysical_map = virtual_to_pysical_map
+
+        for resharding_task in pipeshard_config.resharding_tasks:
+            if global_config.resharding_mode == "send_recv":
+                task_dones = []
+                for v_worker, task in resharding_task.sender_tasks.items():
+                    uuid = resharding_task.send_worker_task_ids[v_worker]
+                    worker = resharding_task.collective_group.mesh_workers[
+                        v_worker.index]
+                    task_dones.append(
+                        worker.put_resharding_send_task.remote(
+                            uuid, task,
+                            resharding_task.collective_group.group_name))
+                for v_worker, task in resharding_task.receiver_tasks.items():
+                    uuid = resharding_task.recv_worker_task_ids[v_worker]
+                    worker = resharding_task.collective_group.mesh_workers[
+                        v_worker.index]
+                    task_dones.append(
+                        worker.put_resharding_recv_task.remote(
+                            uuid, task,
+                            resharding_task.collective_group.group_name))
+                ray.get(task_dones)
+
+                task_dones = []
+                if resharding_task.is_local_allgather_task:
+                    uuid = resharding_task.allgather_uuid
+                    task_spec = resharding_task.task_spec
+                    hlo = compile_allgather(task_spec.aval.shape,
+                                            task_spec.aval.dtype,
+                                            task_spec.dst_sharding_spec,
+                                            task_spec.final_dst_spec,
+                                            np.prod(resharding_task.dst_mesh.shape))
+                    for v_worker in resharding_task.dst_mesh.workers:
+                        worker = resharding_task.collective_group.mesh_workers[
+                            v_worker.index]
+                        task_dones.append(
+                            worker.put_executable.remote(
+                                uuid, UtilMeshWorkerExecutable, hlo))
+                ray.get(task_dones)
+            else:
+                task_dones = []
+                for v_worker, task in resharding_task._broadcast_tasks.items():
+                    uuid = resharding_task.broadcast_worker_task_ids[v_worker]
+                    worker = resharding_task.collective_group.mesh_workers[
+                        v_worker.index]
+                    task_dones.append(
+                        worker.put_resharding_broadcast_task.remote(
+                            uuid, task,
+                            resharding_task.collective_group.group_name))
+                ray.get(task_dones)


########################################
# Device Cluster
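The instantiate method added above does two jobs: it flattens the Ray workers of all launched meshes into one global rank order, swapping every compile-time VirtualWorker placeholder for the real worker at its index, and it then registers the precompiled resharding tasks on those workers. Below is a minimal, runnable sketch of the rebinding step; every name in it is an illustrative stand-in, not Alpa's actual class or API.

```python
# A minimal sketch of the rebinding idea above; VirtualWorker and all other
# names here are illustrative stand-ins, not Alpa's actual classes.

class VirtualWorker:
    """Compile-time placeholder that only remembers its flat index."""

    def __init__(self, index):
        self.index = index


def rebind_workers(meshes, device_to_worker):
    """Flatten real workers across meshes, rank them, resolve placeholders."""
    flat_workers = [w for mesh in meshes for w in mesh]
    worker_to_rank = {w: r for r, w in enumerate(flat_workers)}
    resolved = {
        dev: (flat_workers[w.index] if isinstance(w, VirtualWorker) else w)
        for dev, w in device_to_worker.items()
    }
    return worker_to_rank, resolved


meshes = [["w0", "w1"], ["w2", "w3"]]  # two meshes, two real workers each
device_to_worker = {"gpu:0": VirtualWorker(0), "gpu:2": VirtualWorker(2)}
ranks, resolved = rebind_workers(meshes, device_to_worker)
print(ranks)     # {'w0': 0, 'w1': 1, 'w2': 2, 'w3': 3}
print(resolved)  # {'gpu:0': 'w0', 'gpu:2': 'w2'}
```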
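The task-registration half of the method follows the standard Ray pattern: issue .remote() calls, collect the returned futures, then block on them together with ray.get. A hedged sketch of that pattern, with a thread pool standing in for Ray actors; put_send_task is a hypothetical name, not Alpa's worker API.

```python
# Register-then-wait pattern: Worker.put_send_task mimics a Ray actor method
# whose .remote() call returns a future immediately.
import concurrent.futures


class Worker:
    def __init__(self, name):
        self.name = name
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def put_send_task(self, uuid, task, group_name):
        # Returns a future, like worker.put_resharding_send_task.remote(...)
        return self._pool.submit(
            lambda: f"{self.name}: registered send task {uuid} on {group_name}")


workers = [Worker("w0"), Worker("w1")]
task_dones = [
    w.put_send_task(uuid=i, task=None, group_name="cg-0-1")
    for i, w in enumerate(workers)
]
# Block until every registration finishes, like ray.get(task_dones).
print([f.result() for f in task_dones])
```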
@@ -2305,6 +2381,78 @@ def profile_all(self, *args, **kwargs):
        return mesh_profiling.profile_all(self, *args, **kwargs)


+#Task 801 - DummyVirtualMesh for interfaces
+class VirtualWorker:
+    def __init__(self, index):
+        self.index = index
+        # Additional attributes or methods of virtual workers
+
+class DummyVirtualMesh(VirtualPhysicalMesh):
+    def __init__(self,
+                 host_ids: Sequence[int],
+                 host_info: Sequence[dict],
+                 num_devices_per_host,
+                 parent: VirtualPhysicalMesh = None,
+                 devices: Sequence[Sequence[int]] = None,
+                 mesh_id: int = None
+                ):
+        super().__init__(host_ids, host_info, num_devices_per_host,
+                         parent, devices)
+        self.host_ips = []
[Review comment · Collaborator] This member seems never used. If so, please remove it
+        self.workers = []  # Virtual workers
+        self.mesh_id = mesh_id
+
+        for host_id in host_ids:
+            self.host_ips.append(host_info[host_id]['NodeName'])
+            self.workers.append(VirtualWorker(mesh_id))
+
+
+#TODO Github Task - VirtualMeshGroup for interfaces
+class VirtualMeshGroup:
+    def __init__(self, sliced_virtual_meshes: List[VirtualPhysicalMesh]):
+        self.sliced_virtual_meshes = self.get_virtual_meshes(
+            sliced_virtual_meshes)
+        self.collective_groups: List[List[Any]] = [
+            [None for _ in range(len(self))] for _ in range(len(self))
+        ]
+        self.launched_nccl = False
+
+    def __getitem__(self, index):
+        return self.sliced_virtual_meshes[index]
+
+    def __len__(self):
+        return len(self.sliced_virtual_meshes)
+
+    def index(self, *args, **kwargs):
+        return self.sliced_virtual_meshes.index(*args, **kwargs)
+
+    def get_virtual_meshes(self, sliced_virtual_meshes):
+        custom_sliced_virtual_meshes = []
+        for mesh_idx, mesh in enumerate(sliced_virtual_meshes):
+            custom_mesh = DummyVirtualMesh(mesh.host_ids, mesh.host_info,
+                                           mesh.num_devices_per_host,
+                                           mesh.parent, mesh.devices, mesh_idx)
+            custom_sliced_virtual_meshes.append(custom_mesh)
+        return custom_sliced_virtual_meshes
+
+    def establish_nccl_group(self,
+                             src_mesh_id: int,
+                             dst_mesh_id: int,
+                             instantiate=False
+                            ):
+        """Establish NCCL group between two meshes."""
+        # pylint: disable=import-outside-toplevel
+        from alpa.pipeline_parallel.cross_mesh_resharding import CollectiveGroup
+
+        assert src_mesh_id < dst_mesh_id
+        if self.collective_groups[src_mesh_id][dst_mesh_id] is not None:
+            # Already established
+            return
+        src_mesh = self.sliced_virtual_meshes[src_mesh_id]
+        dst_mesh = self.sliced_virtual_meshes[dst_mesh_id]
+        device_strs = OrderedSet(src_mesh.device_strs + dst_mesh.device_strs)
+        cg = CollectiveGroup(device_strs, src_mesh, dst_mesh)
+        self.collective_groups[src_mesh_id][dst_mesh_id] = cg
+        self.collective_groups[dst_mesh_id][src_mesh_id] = cg


# Global runtime objects
global_cluster: DeviceCluster = None
global_physical_mesh: PhysicalDeviceMesh = None
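VirtualMeshGroup is what lets compilation run before any Ray resources exist: it mirrors the indexing interface of PhysicalDeviceMeshGroup and records which collective groups will be needed, without touching NCCL. A minimal sketch of that record-only bookkeeping follows; LazyGroupTable is a hypothetical stand-in for the CollectiveGroup wiring above.

```python
# Record-only collective-group bookkeeping, assuming the symmetric table
# layout used above. No communicator is created until launch time.

class LazyGroupTable:
    """Records pairwise groups at compile time without allocating anything."""

    def __init__(self, num_meshes):
        self.groups = [[None] * num_meshes for _ in range(num_meshes)]

    def establish(self, src, dst):
        assert src < dst
        if self.groups[src][dst] is not None:
            return  # already recorded
        tag = f"group-{src}-{dst}"  # stands in for a real CollectiveGroup
        self.groups[src][dst] = tag
        self.groups[dst][src] = tag  # symmetric entry, as in the code above


table = LazyGroupTable(3)
table.establish(0, 1)
table.establish(1, 2)
print(table.groups[1][0])  # group-0-1: src and dst share one recorded group
```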
19 changes: 15 additions & 4 deletions alpa/pipeline_parallel/compile_executable.py
@@ -10,7 +10,7 @@
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef

-from alpa.device_mesh import VirtualPhysicalMesh
+from alpa.device_mesh import VirtualPhysicalMesh, VirtualMeshGroup
from alpa.global_env import global_config
from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable
from alpa.pipeline_parallel.runtime_emitter import (
@@ -114,6 +114,10 @@ def compile_pipeshard_executable(
        default_as_option, stage_option, name_base, global_input_shardings,
        None, stage_input_shardings, parsed_ms_option)

+    #Task 801
+    if virtual_mesh.launched_physical_mesh_group is None:
+        virtual_mesh.get_physical_mesh_group(
+            pipeshard_config.sliced_virtual_meshes, pipeshard_config)
+
    executable = PipeshardDriverExecutable(
        mesh_group=virtual_mesh.launched_physical_mesh_group,
        pipeshard_config=pipeshard_config,
@@ -147,6 +151,7 @@ def compile_pipeshard_executable_internal(
        stage_input_shardings: Forcibly set sharding specs of input vars of
            each stage.
    """
+    #global virtual_meshes
    global_invars = closed_jaxpr.jaxpr.invars
    gensym_func = gensym([closed_jaxpr.jaxpr])
    inference_mode = (pipeline_schedule == "inference")
@@ -244,9 +249,13 @@
    total_flops *= num_microbatch
    debug_compilation_time("shard stages")

-    # Launch the physical mesh group
    if virtual_mesh.launched_physical_mesh_group is None:
-        virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)
+        # Launch the virtual mesh group
+        meshes = VirtualMeshGroup(sliced_virtual_meshes)
+    else:
+        # get the already launched physical mesh group
+        meshes = virtual_mesh.launched_physical_mesh_group
+
    debug_compilation_time("launch meshes")

    # Wrap all things into a distributed runtime
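This hunk carries the core reordering of the PR: on the first pass no physical mesh group exists yet, so the runtime config is emitted against a VirtualMeshGroup, and the physical meshes are launched afterwards in compile_pipeshard_executable (see the hunk above), whose mesh-group constructor then calls instantiate. A hedged, runnable sketch of that compile-first, launch-later flow; Cluster, compile_pipeshard, and the other names are stubs, not Alpa's API.

```python
# Compile-first, launch-later flow with stub names.

class Cluster:
    def __init__(self):
        self.launched_group = None

    def launch_physical_group(self, num_meshes, config):
        # In the PR this creates Ray-backed meshes, and the group's __init__
        # immediately calls instantiate(config) to rebind virtual workers.
        self.launched_group = [f"physical-mesh-{i}" for i in range(num_meshes)]
        print("launched after compilation, using", config)


def compile_pipeshard(cluster, num_meshes):
    if cluster.launched_group is None:
        meshes = [f"virtual-mesh-{i}" for i in range(num_meshes)]  # placeholders
    else:
        meshes = cluster.launched_group  # reuse already-launched meshes
    return {"meshes": meshes}  # stands in for PipeshardConfig


cluster = Cluster()
config = compile_pipeshard(cluster, num_meshes=2)  # compile against virtual meshes
if cluster.launched_group is None:
    cluster.launch_physical_group(2, config)       # then launch for real
```

The upshot is that compilation no longer blocks on Ray actor startup, and the same emitted config is reused when the physical group finally comes up.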
@@ -256,7 +265,8 @@
        grad_dummy_invars=accumulator_mapping,
        global_outvars=global_outvars,
        concat_vars_mapping=concat_vars_mapping,
-        mesh_group=virtual_mesh.launched_physical_mesh_group,
+        mesh_group=meshes,
+        sliced_meshes=sliced_virtual_meshes,
        schedule=schedule,
        is_batch=batch_invars,
        num_batch=num_microbatch,
@@ -277,6 +287,7 @@
    return pipeshard_config


+
def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,
                             num_microbatch, inference_mode, gensym_func):
    """Split and process the input jaxpr with the following steps: