Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unmarshaling slice of objects in yaml config #786

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions internal/tools/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (

"github.com/FZambia/viper-lite"
"github.com/google/uuid"
"github.com/mitchellh/mapstructure"
"github.com/rs/zerolog/log"
)

// pathExists returns whether the given file or directory exists or not
Expand Down Expand Up @@ -168,3 +170,67 @@ func OptionalStringChoice(v *viper.Viper, key string, choices []string) (string,
}
return val, nil
}

// DecoderConfig returns default mapstructure.DecoderConfig with support
// of time.Duration values & string slices & Duration
func DecoderConfig(output any) *mapstructure.DecoderConfig {
return &mapstructure.DecoderConfig{
Metadata: nil,
Result: output,
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
StringToDurationHookFunc(),
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
),
}
}

func DecodeSlice(v *viper.Viper, dst any, key string) []byte {
var jsonData []byte
var err error
switch val := v.Get(key).(type) {
case string:
jsonData = []byte(val)
err = json.Unmarshal([]byte(val), dst)
case []any:
jsonData, err = json.Marshal(translateMap(val))
if err != nil {
log.Fatal().Err(err).Msgf("error marshalling config %s slice", key)
}
decoderCfg := DecoderConfig(dst)
decoder, newErr := mapstructure.NewDecoder(decoderCfg)
if newErr != nil {
log.Fatal().Msg(newErr.Error())
}
err = decoder.Decode(v.Get(key))
default:
err = fmt.Errorf("unknown %s type: %T", key, val)
}
if err != nil {
log.Fatal().Err(err).Msgf("malformed %s", key)
}
return jsonData
}

// translateMap is a helper to deal with map[any]any which YAML uses when unmarshalling.
// We always use string keys and not making this transform results into errors on JSON marshaling.
func translateMap(input []any) []map[string]any {
var result []map[string]any
for _, elem := range input {
switch v := elem.(type) {
case map[any]any:
translatedMap := make(map[string]any)
for key, value := range v {
stringKey := fmt.Sprintf("%v", key)
translatedMap[stringKey] = value
}
result = append(result, translatedMap)
case map[string]any:
result = append(result, v)
default:
log.Fatal().Msgf("invalid type in slice: %T", elem)
}
}
return result
}
15 changes: 0 additions & 15 deletions internal/tools/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,6 @@ import (
"github.com/mitchellh/mapstructure"
)

// DecoderConfig returns default mapstructure.DecoderConfig with support
// of time.Duration values & string slices & Duration
func DecoderConfig(output any) *mapstructure.DecoderConfig {
return &mapstructure.DecoderConfig{
Metadata: nil,
Result: output,
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
StringToDurationHookFunc(),
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
),
}
}

type Duration time.Duration

// StringToDurationHookFunc returns a DecodeHookFunc that converts
Expand Down
96 changes: 6 additions & 90 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package main
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
stdlog "log"
Expand Down Expand Up @@ -71,7 +70,6 @@ import (
"github.com/centrifugal/centrifuge"
"github.com/justinas/alice"
"github.com/mattn/go-isatty"
"github.com/mitchellh/mapstructure"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/quic-go/quic-go/http3"
Expand Down Expand Up @@ -1758,28 +1756,7 @@ func rpcNamespacesFromConfig(v *viper.Viper) []rule.RpcNamespace {
if !v.IsSet("rpc_namespaces") {
return ns
}
var jsonData []byte
var err error
switch val := v.Get("rpc_namespaces").(type) {
case string:
jsonData, _ = json.Marshal(val)
err = json.Unmarshal([]byte(val), &ns)
case []any:
jsonData, _ = json.Marshal(val)
decoderCfg := tools.DecoderConfig(&ns)
decoder, newErr := mapstructure.NewDecoder(decoderCfg)
if newErr != nil {
log.Fatal().Msg(newErr.Error())
return ns
}
err = decoder.Decode(v.Get("rpc_namespaces"))
default:
err = fmt.Errorf("unknown rpc_namespaces type: %T", val)
}
if err != nil {
log.Error().Err(err).Msg("malformed rpc_namespaces")
os.Exit(1)
}
jsonData := tools.DecodeSlice(v, &ns, "rpc_namespaces")
rule.WarnUnknownRpcNamespaceKeys(jsonData)
return ns
}
Expand All @@ -1790,28 +1767,7 @@ func namespacesFromConfig(v *viper.Viper) []rule.ChannelNamespace {
if !v.IsSet("namespaces") {
return ns
}
var jsonData []byte
var err error
switch val := v.Get("namespaces").(type) {
case string:
jsonData = []byte(val)
err = json.Unmarshal([]byte(val), &ns)
case []any:
jsonData, _ = json.Marshal(val)
decoderCfg := tools.DecoderConfig(&ns)
decoder, newErr := mapstructure.NewDecoder(decoderCfg)
if newErr != nil {
log.Fatal().Msg(newErr.Error())
return ns
}
err = decoder.Decode(v.Get("namespaces"))
default:
err = fmt.Errorf("unknown namespaces type: %T", val)
}
if err != nil {
log.Error().Err(err).Msg("malformed namespaces")
os.Exit(1)
}
jsonData := tools.DecodeSlice(v, &ns, "namespaces")
rule.WarnUnknownNamespaceKeys(jsonData)
return ns
}
Expand All @@ -1824,27 +1780,9 @@ func granularProxiesFromConfig(v *viper.Viper) []proxy.Config {
if !v.IsSet("proxies") {
return proxies
}
var jsonData []byte
var err error
switch val := v.Get("proxies").(type) {
case string:
jsonData = []byte(val)
err = json.Unmarshal([]byte(val), &proxies)
case []any:
jsonData, _ = json.Marshal(val)
decoderCfg := tools.DecoderConfig(&proxies)
decoder, newErr := mapstructure.NewDecoder(decoderCfg)
if newErr != nil {
log.Fatal().Msg(newErr.Error())
return proxies
}
err = decoder.Decode(v.Get("proxies"))
default:
err = fmt.Errorf("unknown proxies type: %T", val)
}
if err != nil {
log.Fatal().Err(err).Msg("malformed proxies")
}
jsonData := tools.DecodeSlice(v, &proxies, "proxies")
proxy.WarnUnknownProxyKeys(jsonData)

names := map[string]struct{}{}
for _, p := range proxies {
if !proxyNameRe.Match([]byte(p.Name)) {
Expand All @@ -1862,8 +1800,6 @@ func granularProxiesFromConfig(v *viper.Viper) []proxy.Config {
names[p.Name] = struct{}{}
}

proxy.WarnUnknownProxyKeys(jsonData)

return proxies
}

Expand All @@ -1873,27 +1809,7 @@ func consumersFromConfig(v *viper.Viper) []consuming.ConsumerConfig {
if !v.IsSet("consumers") {
return consumers
}
var jsonData []byte
var err error
switch val := v.Get("consumers").(type) {
case string:
jsonData, _ = json.Marshal(val)
err = json.Unmarshal([]byte(val), &consumers)
case []any:
jsonData, _ = json.Marshal(val)
decoderCfg := tools.DecoderConfig(&consumers)
decoder, newErr := mapstructure.NewDecoder(decoderCfg)
if newErr != nil {
log.Fatal().Msg(newErr.Error())
}
err = decoder.Decode(v.Get("consumers"))
default:
err = fmt.Errorf("unknown consumers type: %T", val)
}
if err != nil {
log.Error().Err(err).Msg("malformed consumers")
os.Exit(1)
}
jsonData := tools.DecodeSlice(v, &consumers, "consumers")
consuming.WarnUnknownConsumerConfigKeys(jsonData)
return consumers
}
Expand Down
Loading