From b207f4c93b304b4c850ff3ad01f68f1ec7ff3688 Mon Sep 17 00:00:00 2001 From: Chris Schinnerl Date: Thu, 14 Nov 2024 16:05:23 +0100 Subject: [PATCH] api: add v2 siamux addresses to HostInfo --- api/host.go | 5 ++- stores/sql/main.go | 94 ++++++++++++++++++++++++++++++---------------- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/api/host.go b/api/host.go index 91534ab26..bb3fe8595 100644 --- a/api/host.go +++ b/api/host.go @@ -116,8 +116,9 @@ type ( } HostInfo struct { - PublicKey types.PublicKey `json:"publicKey"` - SiamuxAddr string `json:"siamuxAddr"` + PublicKey types.PublicKey `json:"publicKey"` + SiamuxAddr string `json:"siamuxAddr"` + V2SiamuxAddresses []string `json:"v2SiamuxAddresses"` } HostInteractions struct { diff --git a/stores/sql/main.go b/stores/sql/main.go index 0343bbc06..63883daf0 100644 --- a/stores/sql/main.go +++ b/stores/sql/main.go @@ -877,39 +877,12 @@ func Hosts(ctx context.Context, tx sql.Tx, opts api.HostOptions) ([]api.Host, er } // fill in v2 addresses - netAddrsStmt, err := tx.Prepare(ctx, "SELECT ha.net_address, ha.protocol FROM host_addresses ha INNER JOIN hosts h ON ha.db_host_id = h.id WHERE h.id = ?") + err = fillInV2Addresses(ctx, tx, hostIDs, func(i int, addrs []string) { + hosts[i].V2SiamuxAddresses = addrs + i++ + }) if err != nil { - return nil, fmt.Errorf("failed to prepare stmt for fetching host addresses: %w", err) - } - defer netAddrsStmt.Close() - - fetchAddrs := func(hostID int64) ([]chain.NetAddress, error) { - rows, err := netAddrsStmt.Query(ctx, hostID) - if err != nil { - return nil, err - } - defer rows.Close() - var addrs []chain.NetAddress - for rows.Next() { - var addr chain.NetAddress - if err := rows.Scan(&addr.Address, (*ChainProtocol)(&addr.Protocol)); err != nil { - return nil, err - } - addrs = append(addrs, addr) - } - return addrs, nil - } - - for i := range hosts { - netAddrs, err := fetchAddrs(hostIDs[i]) - if err != nil { - return nil, fmt.Errorf("failed to fetch net addresses for host %d: %w", hostIDs[i], err) - } - for _, na := range netAddrs { - if na.Protocol == rhp4.ProtocolTCPSiaMux { - hosts[i].V2SiamuxAddresses = append(hosts[i].V2SiamuxAddresses, na.Address) - } - } + return nil, err } // query host checks @@ -2392,6 +2365,7 @@ EXISTS ( // query hosts rows, err := tx.Query(ctx, fmt.Sprintf(` SELECT + h.id, h.public_key, COALESCE(h.net_address, ""), COALESCE(h.settings->>'$.siamuxport', "") AS siamux_port, @@ -2408,12 +2382,14 @@ EXISTS ( defer rows.Close() var hosts []HostInfo + var hostIDs []int64 for rows.Next() { + var hostID int64 var hk PublicKey var addr, port string var pt PriceTable var hs HostSettings - err := rows.Scan(&hk, &addr, &port, &pt, &hs) + err := rows.Scan(&hostID, &hk, &addr, &port, &pt, &hs) if err != nil { return nil, fmt.Errorf("failed to scan host: %w", err) } @@ -2432,7 +2408,18 @@ EXISTS ( rhpv2.HostSettings(hs), rhpv3.HostPriceTable(pt), }) + hostIDs = append(hostIDs, hostID) } + + // fill in v2 addresses + err = fillInV2Addresses(ctx, tx, hostIDs, func(i int, addrs []string) { + hosts[i].V2SiamuxAddresses = addrs + i++ + }) + if err != nil { + return nil, err + } + return hosts, nil } @@ -2804,6 +2791,47 @@ func Object(ctx context.Context, tx Tx, bucket, key string) (api.Object, error) }, nil } +func fillInV2Addresses(ctx context.Context, tx sql.Tx, hostIDs []int64, assignFn func(int, []string)) error { + // fill in v2 addresses + netAddrsStmt, err := tx.Prepare(ctx, "SELECT ha.net_address, ha.protocol FROM host_addresses ha INNER JOIN hosts h ON ha.db_host_id = h.id WHERE h.id = ?") + if err != nil { + return fmt.Errorf("failed to prepare stmt for fetching host addresses: %w", err) + } + defer netAddrsStmt.Close() + + fetchAddrs := func(hostID int64) ([]chain.NetAddress, error) { + rows, err := netAddrsStmt.Query(ctx, hostID) + if err != nil { + return nil, err + } + defer rows.Close() + var addrs []chain.NetAddress + for rows.Next() { + var addr chain.NetAddress + if err := rows.Scan(&addr.Address, (*ChainProtocol)(&addr.Protocol)); err != nil { + return nil, err + } + addrs = append(addrs, addr) + } + return addrs, nil + } + + for i, hostID := range hostIDs { + netAddrs, err := fetchAddrs(hostID) + if err != nil { + return fmt.Errorf("failed to fetch net addresses for host %d: %w", hostIDs[i], err) + } + var addrs []string + for _, addr := range netAddrs { + if addr.Protocol == rhp4.ProtocolTCPSiaMux { + addrs = append(addrs, addr.Address) + } + } + assignFn(i, addrs) + } + return nil +} + func listObjectsNoDelim(ctx context.Context, tx Tx, bucket, prefix, substring, sortBy, sortDir, marker string, limit int, slabEncryptionKey object.EncryptionKey) (api.ObjectsResponse, error) { // fetch one more to see if there are more entries if limit <= -1 {