Skip to content

Commit

Permalink
api: add v2 siamux addresses to HostInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisSchinnerl committed Nov 14, 2024
1 parent 98e9c47 commit b207f4c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 35 deletions.
5 changes: 3 additions & 2 deletions api/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ type (
}

HostInfo struct {
PublicKey types.PublicKey `json:"publicKey"`
SiamuxAddr string `json:"siamuxAddr"`
PublicKey types.PublicKey `json:"publicKey"`
SiamuxAddr string `json:"siamuxAddr"`
V2SiamuxAddresses []string `json:"v2SiamuxAddresses"`
}

HostInteractions struct {
Expand Down
94 changes: 61 additions & 33 deletions stores/sql/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,39 +877,12 @@ func Hosts(ctx context.Context, tx sql.Tx, opts api.HostOptions) ([]api.Host, er
}

// fill in v2 addresses
netAddrsStmt, err := tx.Prepare(ctx, "SELECT ha.net_address, ha.protocol FROM host_addresses ha INNER JOIN hosts h ON ha.db_host_id = h.id WHERE h.id = ?")
err = fillInV2Addresses(ctx, tx, hostIDs, func(i int, addrs []string) {
hosts[i].V2SiamuxAddresses = addrs
i++
})
if err != nil {
return nil, fmt.Errorf("failed to prepare stmt for fetching host addresses: %w", err)
}
defer netAddrsStmt.Close()

fetchAddrs := func(hostID int64) ([]chain.NetAddress, error) {
rows, err := netAddrsStmt.Query(ctx, hostID)
if err != nil {
return nil, err
}
defer rows.Close()
var addrs []chain.NetAddress
for rows.Next() {
var addr chain.NetAddress
if err := rows.Scan(&addr.Address, (*ChainProtocol)(&addr.Protocol)); err != nil {
return nil, err
}
addrs = append(addrs, addr)
}
return addrs, nil
}

for i := range hosts {
netAddrs, err := fetchAddrs(hostIDs[i])
if err != nil {
return nil, fmt.Errorf("failed to fetch net addresses for host %d: %w", hostIDs[i], err)
}
for _, na := range netAddrs {
if na.Protocol == rhp4.ProtocolTCPSiaMux {
hosts[i].V2SiamuxAddresses = append(hosts[i].V2SiamuxAddresses, na.Address)
}
}
return nil, err
}

// query host checks
Expand Down Expand Up @@ -2392,6 +2365,7 @@ EXISTS (
// query hosts
rows, err := tx.Query(ctx, fmt.Sprintf(`
SELECT
h.id,
h.public_key,
COALESCE(h.net_address, ""),
COALESCE(h.settings->>'$.siamuxport', "") AS siamux_port,
Expand All @@ -2408,12 +2382,14 @@ EXISTS (
defer rows.Close()

var hosts []HostInfo
var hostIDs []int64
for rows.Next() {
var hostID int64
var hk PublicKey
var addr, port string
var pt PriceTable
var hs HostSettings
err := rows.Scan(&hk, &addr, &port, &pt, &hs)
err := rows.Scan(&hostID, &hk, &addr, &port, &pt, &hs)
if err != nil {
return nil, fmt.Errorf("failed to scan host: %w", err)
}
Expand All @@ -2432,7 +2408,18 @@ EXISTS (
rhpv2.HostSettings(hs),
rhpv3.HostPriceTable(pt),
})
hostIDs = append(hostIDs, hostID)
}

// fill in v2 addresses
err = fillInV2Addresses(ctx, tx, hostIDs, func(i int, addrs []string) {
hosts[i].V2SiamuxAddresses = addrs
i++
})
if err != nil {
return nil, err
}

return hosts, nil
}

Expand Down Expand Up @@ -2804,6 +2791,47 @@ func Object(ctx context.Context, tx Tx, bucket, key string) (api.Object, error)
}, nil
}

func fillInV2Addresses(ctx context.Context, tx sql.Tx, hostIDs []int64, assignFn func(int, []string)) error {
// fill in v2 addresses
netAddrsStmt, err := tx.Prepare(ctx, "SELECT ha.net_address, ha.protocol FROM host_addresses ha INNER JOIN hosts h ON ha.db_host_id = h.id WHERE h.id = ?")
if err != nil {
return fmt.Errorf("failed to prepare stmt for fetching host addresses: %w", err)
}
defer netAddrsStmt.Close()

fetchAddrs := func(hostID int64) ([]chain.NetAddress, error) {
rows, err := netAddrsStmt.Query(ctx, hostID)
if err != nil {
return nil, err
}
defer rows.Close()
var addrs []chain.NetAddress
for rows.Next() {
var addr chain.NetAddress
if err := rows.Scan(&addr.Address, (*ChainProtocol)(&addr.Protocol)); err != nil {
return nil, err
}
addrs = append(addrs, addr)
}
return addrs, nil
}

for i, hostID := range hostIDs {
netAddrs, err := fetchAddrs(hostID)
if err != nil {
return fmt.Errorf("failed to fetch net addresses for host %d: %w", hostIDs[i], err)
}
var addrs []string
for _, addr := range netAddrs {
if addr.Protocol == rhp4.ProtocolTCPSiaMux {
addrs = append(addrs, addr.Address)
}
}
assignFn(i, addrs)
}
return nil
}

func listObjectsNoDelim(ctx context.Context, tx Tx, bucket, prefix, substring, sortBy, sortDir, marker string, limit int, slabEncryptionKey object.EncryptionKey) (api.ObjectsResponse, error) {
// fetch one more to see if there are more entries
if limit <= -1 {
Expand Down

0 comments on commit b207f4c

Please sign in to comment.