add commit #1

Open · wants to merge 1 commit into base: main
36 changes: 32 additions & 4 deletions alpa/device_mesh.py
@@ -64,7 +64,7 @@
from alpa.collective import worker_nccl_util

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG)

ReshardingTileSpec = namedtuple("ReshardingTileSpec",
["offset", "rank", "gpu_idx"])
@@ -124,6 +124,7 @@ def __init__(self, server_address: str, num_hosts: int, host_id: int,
self.distributed_client.connect()
logger.debug(
f"{host_id}: Success to connect to xla runtime at {server_address}")
print("MeshHostWorker: successfully connect to XLA server")
if global_config.backend == "gpu":
self.backend = xla_client.make_gpu_client(self.distributed_client,
node_id=host_id)
@@ -1005,10 +1006,12 @@ def __init__(self,
pass

if found_existing_workers:
print("found_existing_workers")
self.service_server = None
self.workers = self.connect_to_existing_workers()
self.launched = False
else:
print("not found_existing_workers, launch xla servers")
self.service_server, self.workers = self.launch_xla_servers()
self.launched = True

@@ -1017,6 +1020,7 @@

def get_host_worker_name(self, host_id):
if self.namespace:
print("worker: ",f"mesh_{self.mesh_id}_host_{host_id}")
return f"mesh_{self.mesh_id}_host_{host_id}"
else:
return None
@@ -1042,7 +1046,8 @@ def launch_xla_servers(self):
service_server = xla_client._xla.get_distributed_runtime_service(
server_address, self.num_hosts, use_coordination_service=False)
logger.debug(f"Success to start XLA gRPC server on port: {port}...")
time.sleep(0.4)
print("start XLA server in DistributedPhysicalMesh :",server_address)
time.sleep(2)

# Launch workers
workers = []
@@ -1117,6 +1122,7 @@ def launch_xla_servers(self):
self.mesh_id, move_worker,
global_config.runtime_random_seed)
workers.append(worker)
print(len(workers) ,"MeshHostWorkers are created")
return service_server, workers

@property
@@ -1783,6 +1789,12 @@ def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray,
########################################
class VirtualPhysicalMesh:
"""
A virtual physical mesh used for pipeline-parallel compilation. A VirtualPhysicalMesh is used at compile time; no actual Workers are allocated for it.
After compilation finishes, we instantiate it as a PhysicalDeviceMesh and launch the Workers.
A VirtualPhysicalMesh can also be sliced into multiple VirtualPhysicalMeshes.
After slicing, each sliced VirtualPhysicalMesh can be instantiated as a PhysicalDeviceMesh.
These sliced physical device meshes can together form a physical device mesh group (PhysicalDeviceMeshGroup)
to enable pipeline parallelism.
A virtual physical mesh used for pipeline parallel compilation.

VirtualPhysicalMesh is used during compile time. We don't allocate actual
@@ -2135,6 +2147,19 @@ def __init__(self,
ray_global_node = ray_worker._global_node
try:
self.head_info = ray_global_node.address_info
# {
# 'node_ip_address':'10.176.22.221',
# 'raylet_ip_address':'10.176.22.221',
# 'redis_address':None,
# 'object_store_address':'/tmp/ray/session_2023-12-16_12-55-25_647403_776800/sockets/plasma_store',
# 'raylet_socket_name':'/tmp/ray/session_2023-12-16_12-55-25_647403_776800/sockets/raylet',
# 'webui_url':'',
# 'session_dir':'/tmp/ray/session_2023-12-16_12-55-25_647403_776800',
# 'metrics_export_port':64426,
# 'gcs_address':'10.176.22.221:6379',
# 'address':'10.176.22.221:6379',
# 'dashboard_agent_listen_port':52365
# }
except AttributeError as ae:
raise RuntimeError(
"Cannot access ray global node. Did you call ray.init?") \
@@ -2151,14 +2176,17 @@ def __init__(self,
in node["Resources"]):
all_host_info.append(node)
all_host_ips.append(key.split("node:")[-1])

#print(all_host_info)[{'NodeID': 'xxx', 'Alive': True, 'NodeManagerAddress': '10.176.22.221', 'NodeManagerHostname': 'rdma221', 'NodeManagerPort': 34439, 'ObjectManagerPort': 46231, 'ObjectStoreSocketName': 'xx', 'RayletSocketName': 'xx', 'MetricsExportPort': 64426, 'NodeName': '10.176.22.221', 'alive': True, 'Resources': {'CPU': 64.0, 'object_store_memory': 72142757068.0, 'memory': 158333099828.0, 'node:10.176.22.221': 1.0, 'GPU': 2.0}},
#{'NodeID': 'xxx', 'Alive': True, 'NodeManagerAddress': '10.176.22.220', 'NodeManagerHostname': 'rdma220', 'NodeManagerPort': 36049, 'ObjectManagerPort': 37143, 'ObjectStoreSocketName': 'xx', 'RayletSocketName': 'xx', 'MetricsExportPort': 63691, 'NodeName': '10.176.22.220', 'alive': True, 'Resources': {'accelerator_type:G': 1.0, 'memory': 179180961792.0, 'object_store_memory': 76791840768.0, 'CPU': 64.0, 'GPU': 2.0, 'node:10.176.22.220': 1.0}}]

# Gather device info
all_host_num_devices = []
for host_info in all_host_info:
number = host_info["Resources"][global_config.ray_accelerator_name]
assert number.is_integer()
all_host_num_devices.append(int(number))

#print(all_host_num_devices)[2,2]
# adjust the resource allocations
# if `num_nodes` is set, use it.
# otherwise, use the number of nodes in cluster
@@ -2181,7 +2209,7 @@ def __init__(self,
self.host_num_devices = all_host_num_devices

# Create placement group
self.namespace = namespace
self.namespace =namespace #'None'
if namespace:
pg_name = namespace + "_pg"
try:
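The translated docstring above describes the compile-then-instantiate workflow: a cluster-wide VirtualPhysicalMesh is sliced into per-stage virtual meshes, and only after compilation are the slices backed by physical meshes with real workers. A minimal sketch of that flow, using only method names that appear elsewhere in this PR (get_virtual_physical_mesh, slice_profiling_submeshes, get_physical_mesh, launched_physical_mesh_group); the device_cluster handle and the (1, 1) slice shape are assumptions, and the snippet needs an Alpa installation plus a running Ray cluster:

# Sketch only: assumes `device_cluster` is an Alpa cluster handle.
virtual_mesh = device_cluster.get_virtual_physical_mesh()  # compile time, no workers yet

# Slice the cluster-wide virtual mesh into per-stage submeshes,
# e.g. one (num_hosts=1, num_devices_per_host=1) slice per pipeline stage.
sliced_virtual_meshes = virtual_mesh.slice_profiling_submeshes(1, 1)

# Each slice is instantiated as a PhysicalDeviceMesh only when needed,
# which launches the MeshHostWorker actors shown earlier in this diff.
physical_meshes = [vm.get_physical_mesh() for vm in sliced_virtual_meshes]

# Together the instantiated meshes act as the PhysicalDeviceMeshGroup
# (cf. virtual_mesh.launched_physical_mesh_group in compile_executable.py)
# that drives pipeline parallelism.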
4 changes: 2 additions & 2 deletions alpa/global_env.py
@@ -94,8 +94,8 @@ def __init__(self):
self.flax_always_use_fp16_embedding = False

########## Options of logging ##########
self.print_compilation_time = False
self.print_auto_layer_stats = False
self.print_compilation_time = True
self.print_auto_layer_stats = True

# Whether to collect activity trace
self.collect_trace = False
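The two flags above are flipped to True in the library source to surface compilation timings and auto-layer statistics. A less invasive alternative is to toggle them from user code; a small sketch, assuming global_config is the module-level singleton defined in alpa/global_env.py and importable as shown (the import path is an assumption):

from alpa.global_env import global_config

# Same effect as this diff, but scoped to one script instead of changing
# the library-wide defaults.
global_config.print_compilation_time = True
global_config.print_auto_layer_stats = True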
6 changes: 4 additions & 2 deletions alpa/pipeline_parallel/compile_executable.py
@@ -43,7 +43,7 @@
OrderedSet, GradFuncTransformContext)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG)


def compile_pipeshard_executable(
@@ -114,7 +114,7 @@ def compile_pipeshard_executable(
batch_invars, virtual_mesh, num_microbatch, pipeline_schedule,
default_as_option, stage_option, name_base, global_input_shardings,
None, stage_input_shardings, parsed_ms_option)

print("create PipeshardDriverExecutable")
executable = PipeshardDriverExecutable(
mesh_group=virtual_mesh.launched_physical_mesh_group,
pipeshard_config=pipeshard_config,
@@ -148,6 +148,7 @@ def compile_pipeshard_executable_internal(
stage_input_shardings: Forcibly set sharding specs of input vars of
each stage.
"""
print("call internal compile")
global_invars = closed_jaxpr.jaxpr.invars
gensym_func = gensym([closed_jaxpr.jaxpr])
inference_mode = (pipeline_schedule == "inference")
@@ -292,6 +293,7 @@ def compile_pipeshard_executable_internal(
pipeshard_config = emitter_cls(**emitter_kwargs).compile()

debug_compilation_time("runtime emitter")
print("internal compile over")
return pipeshard_config


4 changes: 2 additions & 2 deletions alpa/pipeline_parallel/pipeshard_executable.py
@@ -35,7 +35,7 @@
traceback_util.register_exclusion(__file__)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG)


class PipeshardDriverExecutable:
@@ -56,7 +56,7 @@ def __init__(self,
self.in_tree = in_tree
self.out_tree = out_tree
self.static_argnums = static_argnums

print("init PipeshardDriverExecutable",self.num_mesh,"meshes")
##### For debugging and serialization #####
self.stages = pipeshard_config.xla_stages
self.schedule = pipeshard_config.schedule
2 changes: 1 addition & 1 deletion alpa/pipeline_parallel/stage_construction.py
@@ -623,7 +623,7 @@ def cluster_layers_and_slice_mesh(
submesh_choices = get_submesh_choices(
virtual_mesh.num_hosts, virtual_mesh.num_devices_per_host,
stage_option.submesh_physical_shape_space,
stage_option.manually_specified_submeshes)
stage_option.manually_specified_submeshes)#(1,1) (1,2)....(1,M),(2,M)...(N,M)
autosharding_configs = get_all_submesh_autosharding_config_choices(
virtual_mesh, submesh_choices,
stage_option.submesh_logical_shape_space, batch_size)
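The comment added to cluster_layers_and_slice_mesh above spells out the submesh search space: single-host submeshes (1, 1), (1, 2), ..., (1, M), followed by full-width multi-host submeshes (2, M), ..., (N, M). A self-contained sketch of that enumeration, assuming the per-host device count M is a power of two; the real get_submesh_choices also honors submesh_physical_shape_space and manually specified submeshes:

from typing import List, Tuple

def enumerate_submesh_choices(num_hosts: int,
                              num_devices_per_host: int) -> List[Tuple[int, int]]:
    """Sketch of the (num_hosts, num_devices) submesh choices in the comment above."""
    choices = []
    # Single-host submeshes: power-of-two slices of one host.
    d = 1
    while d <= num_devices_per_host:
        choices.append((1, d))
        d *= 2
    # Multi-host submeshes: always use every device of each host.
    for h in range(2, num_hosts + 1):
        choices.append((h, num_devices_per_host))
    return choices

# A 2-host, 2-GPU-per-host cluster (as in the Ray output above)
# yields [(1, 1), (1, 2), (2, 2)].
print(enumerate_submesh_choices(2, 2))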
26 changes: 19 additions & 7 deletions alpa/pipeline_parallel/stage_profiling.py
@@ -293,8 +293,8 @@ class CompileWorkerPool(BaseWorkerPoolWrapper):

def __init__(self, num_cpus, debug_mode=False):
super().__init__()
worker_cls = ray.remote(num_cpus=1)(CompileWorker)
self.actors = [worker_cls.remote() for _ in range(num_cpus)]
worker_cls = ray.remote(num_cpus=1)(CompileWorker)#actor class
self.actors = [worker_cls.remote() for _ in range(num_cpus)]#get remote actor handles
self.pool = ActorPool(self.actors)
self.local_worker = CompileWorker() if debug_mode else None

@@ -342,6 +342,7 @@ def _profile_impl(self, stage_id, compiled_module_output, stage_plan,
peak memory during the computation, the total available memory,
the input intermediate size and input initial size.
"""
print("profiling stage: ",stage_id)
input_avals = profile_config.invar_avals
output_avals = profile_config.outvar_avals
donated_invars = profile_config.donated_invars
@@ -357,12 +358,17 @@
stage_plan, input_avals,
output_avals,
donated_invars)
print("get executable")

# Run profiling
self.mesh.reset_memory_stats()
print("reset memory")
peak_memory = executable.get_total_allocation_size()
print("get peak memory",peak_memory)
available_memory = self.mesh.get_available_memory()
cost = executable.profile_with_dummy_inputs(skip_grad_sync=True)
print("get available memory",available_memory)
cost = executable.profile_with_dummy_inputs(skip_grad_sync=True)#problem
print("test dummy cost",cost)
del executable

return stage_id, cost, peak_memory, available_memory
@@ -393,6 +399,7 @@ def profile(self, stage_id, compiled_output, stage_plan, profile_info):

def restart(self, forced):
"""Restart the physical mesh."""
print("restart profile worker")
self.mesh.shutdown(forced=forced)
self.virtual_mesh.launched_physical_mesh = None
self.mesh = self.virtual_mesh.get_physical_mesh()
@@ -408,6 +415,7 @@ def __init__(self, virtual_meshes, placement_group):
worker_cls.options(placement_group=placement_group).remote(mesh)
for mesh in virtual_meshes
]
print("number of actors:", len(self.actors))
self.pool = ActorPool(self.actors)


@@ -596,10 +604,13 @@ def profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes,
prof_result,
mesh_num_devices,
num_micro_batches)
print("use HloCostModelProfileWorkerPool")
else:
print("use ProfileWorkerPool")
profile_workers = ProfileWorkerPool(meshes, placement_group)

successful_compile_ct = 0
print(len(stages)," is profiling")
for i, (compiled_output, stage) in enumerate(zip(compiled_outputs, stages)):
if compiled_output is None:
continue
@@ -614,12 +625,12 @@ def profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes,
((i, module_id), acc_grad_module,
compiled_output.stage_plan, profile_config))
successful_compile_ct += 1

pbar = tqdm.tqdm(range(successful_compile_ct))
for _ in pbar:
try:
((i, module_id),
*module_raw_result) = profile_workers.get_next_unordered()
*module_raw_result) = profile_workers.get_next_unordered()#problem
print(i,module_id)
except TimeoutError:
profile_workers.shutdown(force=True)
logger.warning("After waiting for too long, "
@@ -686,7 +697,7 @@ def generate_training_stages_2d(layers,
if apply_grad_layers[idx] is not None
]
stage_name = f"stage_{start}_{end}"
stage_config = generate_stage_info(
stage_config = generate_stage_info(#jax to hlo
layers, [forward_layer_indices, backward_layer_indices],
accumulator_mapping, acc_grad_invars, acc_grad_outvars,
stage_name, selected_apply_grad_layers, apply_grad_global_info)
@@ -1242,7 +1253,7 @@ def get_compute_cost(
).get_virtual_physical_mesh()
sliced_virtual_meshes = (
whole_cluster_virtual_mesh.slice_profiling_submeshes(
num_hosts, num_devices_per_host))
num_hosts, num_devices_per_host))#list of virtualPhysicalMesh
else:
sliced_virtual_meshes = virtual_mesh.slice_profiling_submeshes(
num_hosts, num_devices_per_host)
@@ -1257,6 +1268,7 @@
sliced_virtual_meshes[0].num_devices, cluster_size,
auto_stage_option.stage_imbalance_tolerance)
else:
#starge==>(stage_indices, stage_config, autosharding_config)
stages = generate_training_stages_2d(
layers, layer_flops_prefix_sum, accumulator_mapping,
acc_grad_invars, acc_grad_outvars, apply_grad_layers,
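Several annotations above ("actor class", "get remote actor handles", the (i, module_id) tag unpacked from get_next_unordered) revolve around Ray's ActorPool, which both CompileWorkerPool and ProfileWorkerPool build on. A minimal standalone example of the submit / get_next_unordered pattern; it needs only Ray, and the SquareWorker class is made up for illustration:

import ray
from ray.util import ActorPool

@ray.remote(num_cpus=1)
class SquareWorker:              # stand-in for CompileWorker / ProfileWorker
    def compute(self, tag, x):
        return tag, x * x        # echo the tag back, like (i, module_id) above

ray.init(ignore_reinit_error=True)
pool = ActorPool([SquareWorker.remote() for _ in range(2)])

# submit() takes a function of (actor, value) that returns an ObjectRef.
for tag, x in enumerate(range(4)):
    pool.submit(lambda actor, arg: actor.compute.remote(*arg), (tag, x))

# Results arrive in completion order, not submission order, which is why
# profile_all tags every task and unpacks the tag when collecting results.
while pool.has_next():
    tag, result = pool.get_next_unordered()
    print(tag, result)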
6 changes: 4 additions & 2 deletions alpa/util.py
@@ -51,7 +51,7 @@
########################################

logger = logging.getLogger(__name__)

logger.setLevel(logging.DEBUG)

def freeze_dict(pytree: PyTreeDef):
"""Convert a pytree to a FrozenDict."""
@@ -1031,9 +1031,10 @@ def profile_xla_executable(compiled, backend, local_devices):

# Run benchmark
def run_func():
print("execute_sharded_on_local_devices ")
device_outputs = compiled.execute_sharded_on_local_devices(
device_inputs)

print("execute_sharded_on_local_devices over")
# Reset the value for donate buffers
ct = 0
for j in range(len(device_inputs)):
@@ -1046,6 +1047,7 @@ def run_func():
try:
costs = benchmark_func(run_func, repeat=3, number=3)
except RuntimeError:
print("error when running benchmark")
costs = cost_failed
return costs

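The prints added to run_func bracket the call that benchmark_func times with repeat=3 and number=3. A simplified, self-contained stand-in for that helper, useful for reproducing the timing loop outside Alpa; the real alpa.util.benchmark_func is assumed to also handle warmup and device synchronization, which this sketch omits:

import time
from typing import Callable, List

def benchmark_func_sketch(run_func: Callable[[], None],
                          repeat: int = 3, number: int = 3) -> List[float]:
    """Run `run_func` `number` times per measurement; return `repeat` averages."""
    costs = []
    for _ in range(repeat):
        tic = time.perf_counter()
        for _ in range(number):
            run_func()
        costs.append((time.perf_counter() - tic) / number)
    return costs

# Each entry is the average wall-clock seconds of one call.
print(benchmark_func_sketch(lambda: sum(range(100_000))))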
18 changes: 18 additions & 0 deletions debug_out/compute_cost.txt
@@ -0,0 +1,18 @@
from,to,submush_choices_idx,autosharding config
[[[[0.00229896 0.00252581]
[ inf inf]]

[[0.00382972 0.00389926]
[0.00471987 0.00322694]]]


[[[ inf inf]
[ inf inf]]

[[0.00275302 0.00274181]
[ inf inf]]]]
Result forward_stage_layer_ids: [[0], [1]]
Result mesh_shapes: [(1, 1), (1, 1)]
Result logical_mesh_shapes: [(1, 1), (1, 1)]
Result autosharding_option_dicts: [{'force_batch_dim_to_mesh_dim': 0}, {}]
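The dump above is a 4-D cost tensor indexed as (start layer, end layer, submesh choice, autosharding config), with inf marking infeasible combinations such as start > end. A self-contained sketch that reads those exact numbers back and picks the cheapest (submesh, autosharding) pair for each feasible layer span; the actual stage construction additionally runs a dynamic program over spans to balance pipeline stages, which this sketch does not attempt:

import numpy as np

# cost[start_layer][end_layer][submesh_choice][autosharding_config], from the dump above.
cost = np.array([
    [[[0.00229896, 0.00252581], [np.inf, np.inf]],
     [[0.00382972, 0.00389926], [0.00471987, 0.00322694]]],
    [[[np.inf, np.inf], [np.inf, np.inf]],
     [[0.00275302, 0.00274181], [np.inf, np.inf]]],
])

for start in range(cost.shape[0]):
    for end in range(cost.shape[1]):
        block = cost[start, end]
        if np.isinf(block).all():
            continue  # e.g. start > end is not a valid layer span
        submesh_idx, config_idx = np.unravel_index(np.argmin(block), block.shape)
        print(f"layers {start}..{end}: submesh_choice={submesh_idx}, "
              f"autosharding_config={config_idx}, "
              f"cost={block[submesh_idx, config_idx]:.6f}s")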
