This repository has been archived by the owner on Oct 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 358
task 801 - launch physical meshes after compilation #938
Open
haifaksh
wants to merge
2
commits into
alpa-projects:main
Choose a base branch
from
haifaksh:haifa-task-801
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,7 +56,9 @@ | |
update_jax_platform, is_ray_node_resource, | ||
try_import_ray_worker, create_placement_group, | ||
get_bundle_idx, retrieve_placement_group, get_bundle2ip, | ||
check_server_port) | ||
check_server_port, compile_allgather) | ||
|
||
|
||
|
||
ray_worker = try_import_ray_worker() | ||
|
||
|
@@ -1951,7 +1953,7 @@ def get_physical_mesh(self, mesh_id: int = 0): | |
mesh_id=mesh_id) | ||
return self.launched_physical_mesh | ||
|
||
def get_physical_mesh_group(self, sliced_virtual_meshes): | ||
def get_physical_mesh_group(self, sliced_virtual_meshes, pipeshard_config): | ||
"""Launch a physical mesh group (which will request resources from | ||
Ray).""" | ||
assert self.launched_physical_mesh_group is None, \ | ||
|
@@ -1972,20 +1974,23 @@ def launch_func(i): | |
threads[i].join() | ||
|
||
self.launched_physical_mesh_group = (PhysicalDeviceMeshGroup( | ||
physical_meshes, self)) | ||
physical_meshes, self, pipeshard_config)) | ||
|
||
return self.launched_physical_mesh_group | ||
|
||
|
||
class PhysicalDeviceMeshGroup: | ||
"""A list of physical devices that forms a pipeline.""" | ||
|
||
def __init__(self, meshes: Sequence[DistributedPhysicalDeviceMesh], | ||
parent: VirtualPhysicalMesh): | ||
parent: VirtualPhysicalMesh, pipeshard_config): | ||
self.meshes = list(meshes) | ||
self.parent = parent | ||
self.collective_groups: List[List[Any]] = [ | ||
[None for _ in range(len(self))] for _ in range(len(self)) | ||
] | ||
#task 801 | ||
self.instantiate(pipeshard_config) | ||
|
||
def __getitem__(self, index): | ||
return self.meshes[index] | ||
|
@@ -2124,6 +2129,77 @@ def _instantiate_nccl_group(cg): | |
else: | ||
cg.instantiate() | ||
|
||
def instantiate(self, pipeshard_config): | ||
from alpa.mesh_executable import UtilMeshWorkerExecutable | ||
|
||
virtual_worker_to_rank_map = {} | ||
virtual_to_pysical_map = {} | ||
self.collective_groups = pipeshard_config.virtual_meshes.collective_groups | ||
# task 801 - replacing virtual workers with ray workers | ||
temp_mesh_grp = [] | ||
for mesh in self.meshes: | ||
for worker in mesh.workers: | ||
temp_mesh_grp.append(worker) | ||
virtual_worker_to_rank_map = { | ||
worker: r for r, worker in enumerate(temp_mesh_grp) | ||
} | ||
for cgp in self.collective_groups: | ||
for cg in cgp: | ||
if cg is not None: | ||
cg.mesh_workers = temp_mesh_grp | ||
cg.worker_to_rank_map = virtual_worker_to_rank_map | ||
for key, worker in cg.device_str_to_mesh_worker_map.items(): | ||
if isinstance(worker, VirtualWorker): | ||
cg.device_str_to_mesh_worker_map[key] = cg.mesh_workers[worker.index] | ||
|
||
for virtual_worker, _ in pipeshard_config.instruction_lists.items(): | ||
virtual_to_pysical_map[virtual_worker.index] = virtual_worker | ||
|
||
pipeshard_config.virtual_worker_to_rank_map = virtual_worker_to_rank_map | ||
pipeshard_config.virtual_to_pysical_map = virtual_to_pysical_map | ||
|
||
for resharding_task in pipeshard_config.resharding_tasks: | ||
if global_config.resharding_mode == "send_recv": | ||
task_dones = [] | ||
for v_worker, task in resharding_task.sender_tasks.items(): | ||
uuid = resharding_task.send_worker_task_ids[v_worker] | ||
worker = resharding_task.collective_group.mesh_workers[v_worker.index] | ||
task_dones.append( | ||
worker.put_resharding_send_task.remote( | ||
uuid, task, resharding_task.collective_group.group_name)) | ||
for v_worker, task in resharding_task.receiver_tasks.items(): | ||
uuid = resharding_task.recv_worker_task_ids[v_worker] | ||
worker = resharding_task.collective_group.mesh_workers[v_worker.index] | ||
task_dones.append( | ||
worker.put_resharding_recv_task.remote( | ||
uuid, task, resharding_task.collective_group.group_name)) | ||
ray.get(task_dones) | ||
|
||
task_dones = [] | ||
if resharding_task.is_local_allgather_task: | ||
uuid = resharding_task.allgather_uuid | ||
task_spec = resharding_task.task_spec | ||
hlo = compile_allgather(task_spec.aval.shape, task_spec.aval.dtype, | ||
task_spec.dst_sharding_spec, | ||
task_spec.final_dst_spec, | ||
np.prod(resharding_task.dst_mesh.shape)) | ||
for v_worker in resharding_task.dst_mesh.workers: | ||
worker = resharding_task.collective_group.mesh_workers[v_worker.index] | ||
task_dones.append( | ||
worker.put_executable.remote(uuid, UtilMeshWorkerExecutable, | ||
hlo)) | ||
ray.get(task_dones) | ||
else: | ||
task_dones = [] | ||
for v_worker, task in resharding_task._broadcast_tasks.items(): | ||
uuid = resharding_task.broadcast_worker_task_ids[v_worker] | ||
worker = resharding_task.collective_group.mesh_workers[v_worker.index] | ||
task_dones.append( | ||
worker.put_resharding_broadcast_task.remote( | ||
uuid, task, resharding_task.collective_group.group_name)) | ||
ray.get(task_dones) | ||
|
||
|
||
|
||
######################################## | ||
# Device Cluster | ||
|
@@ -2305,6 +2381,78 @@ def profile_all(self, *args, **kwargs): | |
return mesh_profiling.profile_all(self, *args, **kwargs) | ||
|
||
|
||
#Task 801 - DummyVirtualMesh for interfaces | ||
class VirtualWorker: | ||
def __init__(self, index): | ||
self.index = index | ||
# Additional attributes or methods of virtual workers | ||
|
||
class DummyVirtualMesh(VirtualPhysicalMesh): | ||
def __init__(self, | ||
host_ids: Sequence[int], | ||
host_info: Sequence[dict], | ||
num_devices_per_host, | ||
parent: VirtualPhysicalMesh = None, | ||
devices: Sequence[Sequence[int]] = None, | ||
mesh_id: int = None | ||
): | ||
super().__init__(host_ids, host_info, num_devices_per_host, parent, devices) | ||
self.host_ips = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This member seems never used. If so, please remove it |
||
self.workers = [] # Virtual workers | ||
self.mesh_id = mesh_id | ||
|
||
for host_id in host_ids: | ||
self.host_ips.append(host_info[host_id]['NodeName']) | ||
self.workers.append(VirtualWorker(mesh_id)) | ||
|
||
|
||
#TODO Github Task - VirtualMeshGroup for interfaces | ||
class VirtualMeshGroup: | ||
def __init__(self, sliced_virtual_meshes: List[VirtualPhysicalMesh]): | ||
self.sliced_virtual_meshes = self.get_virtual_meshes(sliced_virtual_meshes) | ||
self.collective_groups: List[List[Any]] = [ | ||
[None for _ in range(len(self))] for _ in range(len(self)) | ||
] | ||
self.launched_nccl = False | ||
|
||
def __getitem__(self, index): | ||
return self.sliced_virtual_meshes[index] | ||
|
||
def __len__(self): | ||
return len(self.sliced_virtual_meshes) | ||
|
||
def index(self, *args, **kwargs): | ||
return self.sliced_virtual_meshes.index(*args, **kwargs) | ||
|
||
def get_virtual_meshes(self, sliced_virtual_meshes): | ||
custom_sliced_virtual_meshes = [] | ||
for mesh_idx, mesh in enumerate(sliced_virtual_meshes): | ||
custom_mesh = DummyVirtualMesh(mesh.host_ids, mesh.host_info, mesh.num_devices_per_host, mesh.parent, mesh.devices, mesh_idx) | ||
custom_sliced_virtual_meshes.append(custom_mesh) | ||
return custom_sliced_virtual_meshes | ||
|
||
def establish_nccl_group(self, | ||
src_mesh_id: int, | ||
dst_mesh_id: int, | ||
instantiate=False | ||
): | ||
"""Establish NCCL group between two meshes.""" | ||
# pylint: disable=import-outside-toplevel | ||
from alpa.pipeline_parallel.cross_mesh_resharding import CollectiveGroup | ||
|
||
assert src_mesh_id < dst_mesh_id | ||
if self.collective_groups[src_mesh_id][dst_mesh_id] is not None: | ||
# Already established | ||
return | ||
src_mesh = self.sliced_virtual_meshes[src_mesh_id] | ||
dst_mesh = self.sliced_virtual_meshes[dst_mesh_id] | ||
device_strs = OrderedSet(src_mesh.device_strs + dst_mesh.device_strs) | ||
cg = CollectiveGroup(device_strs, src_mesh, dst_mesh) | ||
self.collective_groups[src_mesh_id][dst_mesh_id] = cg | ||
self.collective_groups[dst_mesh_id][src_mesh_id] = cg | ||
|
||
|
||
|
||
# Global runtime objects | ||
global_cluster: DeviceCluster = None | ||
global_physical_mesh: PhysicalDeviceMesh = None | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove unused lines. Also applies for some code below