Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Construct "real" topology #168

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 122 additions & 60 deletions scripts/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Iterable
from typing import List, Optional, Iterable, Dict
from itertools import chain
from collections import defaultdict
import json
Expand Down Expand Up @@ -398,95 +398,157 @@ def install_gres_conf(lkp: util.Lookup) -> None:
class Switch:
"""
Represents a switch in the topology.conf file.
NOTE: It's class user job to make sure that there is no leaf-less Switches in the tree
"""
# NOTE(review): this span looks like a diff whose +/- markers were lost: it
# carries two `nodes`/`switches` parameter pairs, two `self.switches`
# assignments and two `render_conf_lines` definitions. Presumably the
# List-based lines are the removed version and the Dict-based lines are the
# added one -- confirm against the real file before relying on either.

def __init__(
self,
name: str,
# TODO: consider using an Iterable instead of list to save memory
# That would require more elaborate behavior of `self.empty()`
nodes: Optional[List[str]] = None,
switches: Optional[List["Switch"]] = None,
nodes: Optional[List[str]] = None, # TODO: consider using generators
switches: Optional[Dict[str, "Switch"]] = None,
):
self.name = name
# A missing (or empty) argument is replaced with a fresh container.
self.nodes = nodes or []
self.switches = switches or []
self.switches = switches or {}

def conf_line(self) -> str:
# Collect the fields of this switch's line; "Nodes" and "Switches" are
# only added when the corresponding container is non-empty.
d = {"SwitchName": self.name}
if self.nodes:
d["Nodes"] = util.to_hostlist_fast(self.nodes)
if self.switches:
# Dict-based variant: sub-switches are referenced by their dict keys.
d["Switches"] = util.to_hostlist_fast(self.switches.keys())
return dict_to_conf(d)

# List-based variant: filters sub-switches through `empty()` before rendering.
non_empty = [
s for s in self.switches if not s.empty()
] # render only non-empty sub switches
if non_empty:
d["Switches"] = util.to_hostlist_fast([s.name for s in non_empty])
def render_conf_lines(self) -> Iterable[str]:
# Yield this switch's own line first, then recurse into the sub-switches
# in order of their names.
yield self.conf_line()
for s in sorted(self.switches.values(), key=lambda s: s.name):
yield from s.render_conf_lines()

return dict_to_conf(d)

def render_conf_lines(self) -> List[str]:
if self.empty():
class Topology:
# Tree of Switch objects hanging off a hidden root switch; switches are
# addressed by a list of names ("path") starting just below that root.
def __init__(self) -> None:
self._r = Switch("") # fake root, not part of the tree

def add(self, path: List[str], nodes: Iterable[str]) -> None:
# Walk down from the root along `path`, creating missing switches on the
# way, then append `nodes` to the last switch of the path.
n = self._r
assert path
for p in path:
# NOTE(review): `Switch(p)` is constructed on every step, even when the
# key already exists and the new object is thrown away.
n = n.switches.setdefault(p, Switch(p))
# Builds a new list instead of mutating the existing one in place.
n.nodes = [*n.nodes, *nodes]

def render_conf_lines(self) -> Iterable[str]:
# Yields the lines of every top-level switch (and, through
# Switch.render_conf_lines, whatever that yields), in name order.
# The fake root itself is never rendered.
if not self._r.switches:
# NOTE(review): the function body contains `yield from`, so this is a
# generator and `return []` merely ends it; the list is not handed to
# a caller that iterates. A bare `return` would say the same thing.
return []
for s in sorted(self._r.switches.values(), key=lambda s: s.name):
yield from s.render_conf_lines()

# NOTE(review): the next four lines use `self.conf_line()` and iterate
# `self.switches` as a list; they look like leftover removed lines of the old
# Switch.render_conf_lines rather than part of Topology -- confirm.
lines = [self.conf_line()]
for s in sorted(self.switches, key=lambda s: s.name):
lines.extend(s.render_conf_lines())
return lines
def compress(self) -> "Topology":
# Returns a new Topology with the same shape in which every switch is
# renamed: "s<i>" directly under the root and "<parent name>_<i>" below,
# where <i> is the switch's index among its siblings sorted by name.
compressed = Topology()

# NOTE(review): `empty` below reads `self.nodes` / `self.switches`, which
# Topology does not set; presumably a removed method of the old Switch.
def empty(self) -> bool:
if self.nodes:
return False
return not any(not c.empty() for c in self.switches)
def _walk(
u: Switch, c: Switch
): # u: uncompressed node, c: counterpart in compressed tree
pref = f"{c.name}_" if c != compressed._r else "s"
for i, us in enumerate(sorted(u.switches.values(), key=lambda s: s.name)):
# The node list of the original switch is passed on as-is (not copied).
cs = Switch(f"{pref}{i}", nodes=us.nodes)
c.switches[cs.name] = cs
_walk(us, cs)

_walk(self._r, compressed._r)
return compressed


# NOTE(review): this span interleaves what look like removed definitions
# (`tpu_nodeset_switch`, `nodeset_switch`, a `gen_topology` returning
# List[Switch]) with the added `add_tpu_nodeset_topology`; the diff markers
# were lost. Comments below describe each fragment only as far as it is visible.
def tpu_nodeset_switch(nodeset: object, lkp: util.Lookup) -> Switch:
def add_tpu_nodeset_topology(nodeset: object, bldr: Topology, lkp: util.Lookup):
tpuobj = util.TPU(nodeset)
static, dynamic = lkp.nodenames(nodeset)

switch = Switch(name=nodeset.nodeset_name)
# Builder-based variant: nodes are registered under
# "tpu-root" -> "ns_<nodeset_name>" through `bldr.add`.
switch_name = f"ns_{nodeset.nodeset_name}"
pref = ["tpu-root", switch_name]
if tpuobj.vmcount == 1: # Put all nodes in one switch
switch.nodes = list(chain(static, dynamic))
else:
# Chunk nodes into sub-switches of size `vmcount`
for nodenames in (static, dynamic):
for nodeschunk in util.chunked(nodenames, n=tpuobj.vmcount):
sub_switch = Switch(
name=f"{switch.name}-{len(switch.switches)}",
nodes=list(nodeschunk),
)
switch.switches.append(sub_switch)
return switch


def nodeset_switch(nodeset: object, lkp: util.Lookup) -> Switch:
# One switch named after the nodeset, holding all names from lkp.nodenames().
return Switch(name=nodeset.nodeset_name, nodes=list(chain(*lkp.nodenames(nodeset))))


def gen_topology(lkp: util.Lookup) -> List[Switch]:
# Returns a list of "root" switches
tpu_root = Switch(
name="nodeset_tpu-root",
switches=[tpu_nodeset_switch(ns, lkp) for ns in lkp.cfg.nodeset_tpu.values()],
)

ord_root = Switch(
name="nodeset-root",
switches=[nodeset_switch(ns, lkp) for ns in lkp.cfg.nodeset.values()],
)

return [tpu_root, ord_root]
# Builder-based variant, `vmcount == 1` case: static and dynamic nodes all go
# directly under `pref`, and the function is done.
bldr.add(pref, list(chain(static, dynamic)))
return

# Chunk nodes into sub-switches of size `vmcount`
# `chunk_num` is initialized once before both loops, so chunk names are
# numbered consecutively across the static and the dynamic node lists.
chunk_num = 0
for nodenames in (static, dynamic):
for nodeschunk in util.chunked(nodenames, n=tpuobj.vmcount):
chunk_name = f"{switch_name}-{chunk_num}"
chunk_num += 1
bldr.add([*pref, chunk_name], list(nodeschunk))


def add_nodeset_phony_topology(
    nodeset: object, topo: Topology, lkp: util.Lookup
) -> None:
    """Put every node of `nodeset` under a single per-nodeset switch.

    The switch path is "slurm-root" -> "ns_<nodeset_name>"; the node names are
    all the name groups returned by `lkp.nodenames(nodeset)`, concatenated.
    """
    switch_path = ["slurm-root", f"ns_{nodeset.nodeset_name}"]
    name_groups = lkp.nodenames(nodeset)
    topo.add(switch_path, list(chain.from_iterable(name_groups)))


def _make_physical_path(inst: object) -> List[str]:
    """Return the switch path (root first) under which `inst` is placed.

    The path is derived from `inst.resourceStatus["physicalHost"]`, a
    "/"-prefixed, "/"-separated string:

    * missing or empty: ["slurm-root", "zone_<zone>"] followed by three
      per-instance padding names, so the result is as deep as the
      three-component case below;
    * four or more components: ["slurm-root", *components];
    * exactly three components: ["slurm-root", "zone_<zone>", *components].

    Raises:
        ValueError: if the value does not start with "/" or has fewer than
            three components.
    """
    root, zone = "slurm-root", f"zone_{inst.zone}"
    physical_host = inst.resourceStatus.get("physicalHost")

    if not physical_host:
        # Padding names embed the instance name, so each instance without
        # placement info ends up in its own chain of switches.
        padding = [f"{inst.name}_pad{i}" for i in reversed(range(3))]
        return [root, zone, *padding]

    # Explicit check instead of `assert`: asserts are stripped under `-O`, and
    # the slice below would then silently drop the first character.
    if not physical_host.startswith("/"):
        raise ValueError(f"Unexpected physicalHost: {physical_host}")

    parts = physical_host[1:].split("/")
    if len(parts) >= 4:
        return [root, *parts]
    if len(parts) == 3:
        # TODO: parts[0] = placement_id + parts[0]
        return [root, zone, *parts]
    raise ValueError(f"Unexpected physicalHost: {physical_host}")


def add_nodeset_real_topology(
    nodeset: object, topo: Topology, lkp: util.Lookup
) -> None:
    """Add the nodes of `nodeset` to `topo`, using physical placement where known.

    Every instance from `lkp.instances()` that resolves to this nodeset is
    added under the path computed by `_make_physical_path`. Node names of the
    nodeset that have no such instance are grouped under the
    "slurm-root" -> "ns_<nodeset_name>" switch.
    """
    placed = set()
    for inst in lkp.instances().values():
        try:
            belongs_here = lkp.node_nodeset_name(inst.name) == nodeset.nodeset_name
        except Exception:
            continue  # fail to lookup nodeset
        if not belongs_here:
            continue

        topo.add(_make_physical_path(inst), [inst.name])
        placed.add(inst.name)

    # Add phony nodes to the topology
    phony_nodes = [
        name for name in chain(*lkp.nodenames(nodeset)) if name not in placed
    ]
    if phony_nodes:
        topo.add(["slurm-root", f"ns_{nodeset.nodeset_name}"], phony_nodes)


def gen_topology(lkp: util.Lookup) -> Topology:
    """Build a Topology covering every nodeset in `lkp.cfg`.

    Regular nodesets use the real (placement-based) layout when their
    `real_topology` flag is truthy and the phony one otherwise; dynamic
    nodesets always use the real layout; TPU nodesets use the TPU layout.
    """
    topo = Topology()

    for ns in lkp.cfg.nodeset.values():
        populate = (
            add_nodeset_real_topology
            if ns.real_topology
            else add_nodeset_phony_topology
        )
        populate(ns, topo, lkp)

    for ns in lkp.cfg.nodeset_dyn.values():
        add_nodeset_real_topology(ns, topo, lkp)

    for ns in lkp.cfg.nodeset_tpu.values():
        add_tpu_nodeset_topology(ns, topo, lkp)

    return topo


def gen_topology_conf(lkp: util.Lookup) -> None:
"""generate slurm topology.conf from config.yaml"""
# NOTE(review): two variants appear interleaved here (diff markers lost): one
# collects `lines` and calls `conf_file.write_text`, the other streams the
# compressed topology through `open(...)`. Confirm which one is live.
lines = [FILE_PREAMBLE]
for r in gen_topology(lkp):
lines.extend(r.render_conf_lines())
lines.append("\n")

# The topology is passed through `compress()` before it is rendered.
topo = gen_topology(lkp).compress()
# Output goes to "cloud_topology.conf" in `lkp.cfg.output_dir`, falling back to
# `slurmdirs.etc` when that is unset/empty.
conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cloud_topology.conf"
conf_file.write_text("\n".join(lines))
with open(conf_file, "w") as f:
# NOTE(review): `writelines` is given a single str, so it iterates it
# character by character; the output is the same as `f.write(...)`.
f.writelines(FILE_PREAMBLE + "\n")
# One rendered line per row, each terminated with a newline.
for line in topo.render_conf_lines():
f.write(line)
f.write("\n")
f.write("\n")
util.chown_slurm(conf_file, mode=0o600)


Expand Down
Loading