Skip to content

Commit

Permalink
feat: 支持配置多个s3文件
Browse files Browse the repository at this point in the history
  • Loading branch information
jimyag committed Nov 25, 2024
1 parent b74f8ce commit 38671d8
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 58 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ read from s3 bucket

``` bash
cat s3.toml
[[s3]]
region = "xxxx"
access_key = "ak"
secret_key = "sk"
Expand Down
9 changes: 6 additions & 3 deletions cmd/cat.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ func init() {
func catRun(cmd *cobra.Command, args []string) {
rdrs, err := getReaders(args)
if err != nil {
log.Panic(err).Msg("error getting readers")
log.Error(err).Msg("error getting readers")
return
}
for _, rdr := range rdrs {
if count == 0 {
Expand All @@ -44,7 +45,8 @@ func catRun(cmd *cobra.Command, args []string) {
for c := range rdr.MetaData().Schema.NumColumns() {
col, err := rgr.Column(c)
if err != nil {
log.Panic(err).Int("column", c).Msg("error getting column")
log.Error(err).Int("column", c).Msg("error getting column")
return
}
scanners[c] = dumper.NewDumper(col, convertInt96AsTime)
fields[c] = col.Descriptor().Path()
Expand Down Expand Up @@ -83,9 +85,10 @@ func catRun(cmd *cobra.Command, args []string) {
}
jsonVal, err := json.Marshal(val)
if err != nil {
log.Panic(err).
log.Error(err).
Str("val", fmt.Sprintf("%+v", val)).
Msg("error marshalling json")
return
}
fmt.Printf("%q: %s", fields[idx], jsonVal)
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ func init() {

func diffRun(cmd *cobra.Command, args []string) {
if len(args) != 2 {
log.Panic().Msg("diff requires two parquet files")
log.Error().Msg("diff requires two parquet files")
return
}
rdrs, err := getReaders(args)
if err != nil {
log.Panic(err).Msg("error getting readers")
log.Error(err).Msg("error getting readers")
return
}
rdr1 := rdrs[0]
rdr2 := rdrs[1]
Expand Down
6 changes: 4 additions & 2 deletions cmd/footer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ func init() {
func footer(cmd *cobra.Command, args []string) {
rdrs, err := getReaders(args)
if err != nil {
log.Panic(err).Msg("error getting readers")
log.Error().Msgf("error getting readers: %s", err)
return
}
for _, rdr := range rdrs {
fileMetadata := rdr.MetaData()
m, err := json.MarshalIndent(fileMetadata, "", " ")
if err != nil {
log.Panic(err).Msg("error marshalling file metadata")
log.Error().Msgf("error marshalling file metadata: %s", err)
return
}
fmt.Println(string(m))
}
Expand Down
9 changes: 6 additions & 3 deletions cmd/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ func init() {
func meta(cmd *cobra.Command, args []string) {
rdrs, err := getReaders(args)
if err != nil {
log.Panic(err).Msg("error getting readers")
log.Error(err).Msg("error getting readers")
return
}
for i, rdr := range rdrs {
fileMetadata := rdr.MetaData()
Expand Down Expand Up @@ -100,12 +101,14 @@ func meta(cmd *cobra.Command, args []string) {
descRecord := fileMetadata.Schema.Column(c)
row = append(row, descRecord.Name())
if err != nil {
log.Panic(err).Msg("error getting column chunk metadata")
log.Error(err).Msg("error getting column chunk metadata")
return
}
if set, _ := chunkMeta.StatsSet(); set {
stats, err := chunkMeta.Statistics()
if err != nil {
log.Panic(err).Msg("error getting column chunk statistics")
log.Error(err).Msg("error getting column chunk statistics")
return
}
row = append(row, fmt.Sprint(chunkMeta.NumValues()))
if stats.HasMinMax() {
Expand Down
149 changes: 108 additions & 41 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
"fmt"
"net/url"
"os"
"path/filepath"
"slices"

"github.com/BurntSushi/toml"
"github.com/apache/arrow/go/v17/parquet/file"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/jimyag/log"
"github.com/spf13/cobra"

"github.com/jimyag/parquet-tools/internal/reader"
Expand All @@ -23,7 +26,6 @@ import (
var rootCmd = &cobra.Command{
Use: "parquet-tools",
Short: "Utility to inspect Parquet files",
Run: func(cmd *cobra.Command, args []string) {},
}

const (
Expand All @@ -33,22 +35,41 @@ const (
httpScheme = "http"
httpsScheme = "https"

s3ConfigFileUsage = `s3 config file format:
----- BEGIN S3 CONFIG -----
endpoint = "http://127.0.0.1:9000"
region = "us-east-1"
access_key = "ak"
secret_key = "sk"
disable_ssl = true
force_path_style = true
----- END S3 CONFIG -----
`
s3ConfigFileUsage = `
# BEGIN S3 CONFIG -----
[[s3]]
endpoint = "http://127.0.0.1:9000"
region = "us-east-1"
access_key = "ak1"
secret_key = "sk1"
disable_ssl = true
force_path_style = true
scopes = ["bucket1", "bucket2"]
[[s3]]
endpoint = "http://127.0.0.1:9000"
region = "us-east-1"
access_key = "ak2"
secret_key = "sk2"
disable_ssl = true
force_path_style = true
scopes = ["bucket3", "bucket4"]
# END S3 CONFIG -----
`
)

var s3ConfigFile string
var s3ConfigFile string = ".parquet-tools/s3.toml"

func init() {
rootCmd.PersistentFlags().StringVarP(&s3ConfigFile, "s3-config", "", "", "s3 config file")
home, err := os.UserHomeDir()
if err != nil {
home = "./"
}
s3ConfigFile = filepath.Join(home, s3ConfigFile)
if _, err := os.Stat(s3ConfigFile); os.IsNotExist(err) {
os.MkdirAll(filepath.Dir(s3ConfigFile), 0700)
os.WriteFile(s3ConfigFile, []byte(s3ConfigFileUsage), 0600)
}
rootCmd.PersistentFlags().StringVarP(&s3ConfigFile, "s3-config", "", s3ConfigFile, "s3 config file")
}

func Execute() {
Expand All @@ -58,13 +79,18 @@ func Execute() {
}
}

type s3Config struct {
Region string `toml:"region" json:"region"`
AccessKey string `toml:"access_key" json:"access_key"`
SecretKey string `toml:"secret_key" json:"secret_key"`
DisableSSL bool `toml:"disable_ssl" json:"disable_ssl"`
ForcePathStyle bool `toml:"force_path_style" json:"force_path_style"`
EndPoint string `toml:"endpoint" json:"endpoint"`
// Config is the top-level schema of the s3 config file (TOML): a list of
// [[s3]] credential entries, each decoded into an s3Cfg.
type Config struct {
	S3 []s3Cfg `toml:"s3" json:"s3"`
}

// s3Cfg holds the connection settings for one [[s3]] entry in the config
// file. Scopes lists bucket names this credential set is known to access;
// getReaders matches the bucket parsed from an s3:// URI against Scopes to
// pick the right credentials (and appends newly discovered buckets).
type s3Cfg struct {
	Region         string   `toml:"region" json:"region"`           // AWS region, e.g. "us-east-1"
	AccessKey      string   `toml:"access_key" json:"access_key"`   // static access key id
	SecretKey      string   `toml:"secret_key" json:"secret_key"`   // static secret access key
	DisableSSL     bool     `toml:"disable_ssl" json:"disable_ssl"` // plain HTTP to the endpoint
	ForcePathStyle bool     `toml:"force_path_style" json:"force_path_style"` // path-style addressing (MinIO etc.)
	EndPoint       string   `toml:"endpoint" json:"endpoint"`       // custom endpoint URL; empty for AWS default
	Scopes         []string `toml:"scopes" json:"scopes"`           // bucket names reachable with these credentials
}

func getReaders(filenames []string) ([]*file.Reader, error) {
Expand Down Expand Up @@ -96,33 +122,74 @@ func getReaders(filenames []string) ([]*file.Reader, error) {
continue
}
if u.Scheme == s3Scheme || u.Scheme == s3aScheme {
cfg := s3Config{}
if s3ConfigFile == "" {
return nil, fmt.Errorf("s3 config file is required for s3 scheme")
}
cfg := Config{}
if _, err := toml.DecodeFile(s3ConfigFile, &cfg); err != nil {
return nil, err
}
mySession := session.Must(session.NewSession(&aws.Config{
Credentials: credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, ""),
Endpoint: aws.String(cfg.EndPoint),
Region: aws.String(cfg.Region),
DisableSSL: aws.Bool(cfg.DisableSSL),
S3ForcePathStyle: aws.Bool(cfg.ForcePathStyle),
}))
s3Cli := s3.New(mySession)
s3Reader, err := reader.NewS3Reader(context.Background(), filename, s3Cli)
if err != nil {
return nil, err
for ic, c := range cfg.S3 {
mySession := session.Must(session.NewSession(&aws.Config{
Credentials: credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, ""),
Endpoint: aws.String(c.EndPoint),
Region: aws.String(c.Region),
DisableSSL: aws.Bool(c.DisableSSL),
S3ForcePathStyle: aws.Bool(c.ForcePathStyle),
}))
bucket, _, err := reader.ParsePath(filename)
if err != nil {
return nil, err
}
s3Cli := s3.New(mySession)
_, err = reader.Stat(context.Background(), filename, s3Cli)
if err == nil {
s3Reader, err := reader.NewS3Reader(context.Background(), filename, s3Cli)
if err != nil {
return nil, err
}
rdr, err := file.NewParquetReader(s3Reader)
if err != nil {
return nil, err
}
readers[i] = rdr
if c.Scopes == nil {
c.Scopes = []string{}
}
if slices.Contains(c.Scopes, bucket) {
continue
}
c.Scopes = append(c.Scopes, bucket)
cfg.S3[ic] = c
// update config file
f, err := os.OpenFile(s3ConfigFile, os.O_WRONLY, 0600)
if err != nil {
log.Error().Msgf("error opening s3 config file: %s", err)
return nil, err
}
if err := toml.NewEncoder(f).Encode(cfg); err != nil {
log.Error().Msgf("error encoding s3 config file: %s", err)
return nil, err
}
f.Close()
break
}
for _, scope := range c.Scopes {
if scope == bucket {
s3Reader, err := reader.NewS3Reader(context.Background(), filename, s3Cli)
if err != nil {
return nil, err
}
rdr, err := file.NewParquetReader(s3Reader)
if err != nil {
return nil, err
}
readers[i] = rdr
break
}
}
}
rdr, err := file.NewParquetReader(s3Reader)
if err != nil {
return nil, err
if readers[i] == nil {
return nil, fmt.Errorf("don't have access to %s", filename)
}
readers[i] = rdr
continue
}
return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
return readers, nil
}
3 changes: 2 additions & 1 deletion cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func init() {
func schemaRun(cmd *cobra.Command, args []string) {
rdrs, err := getReaders(args)
if err != nil {
log.Panic(err).Msg("error getting readers")
log.Error(err).Msg("error getting readers")
return
}
for _, rdr := range rdrs {
schema.PrintSchema(rdr.MetaData().Schema.Root(), os.Stdout, 2)
Expand Down
3 changes: 2 additions & 1 deletion cmd/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ func init() {
func structRun(cmd *cobra.Command, args []string) {
rdrs, err := getReaders(args)
if err != nil {
log.Panic(err).Msg("error getting readers")
log.Error(err).Msg("error getting readers")
return
}
for _, rdr := range rdrs {
parquetSchema := rdr.MetaData().Schema.Root()
Expand Down
13 changes: 8 additions & 5 deletions internal/reader/reader_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/jimyag/log"
)

var (
Expand All @@ -32,13 +33,15 @@ type S3Reader struct {
body io.ReadCloser
}

func parsePath(path string) (bucket, key string, err error) {
func ParsePath(path string) (bucket, key string, err error) {
u, err := url.Parse(path)
if err != nil {
return
}
if u.Scheme != "s3" && u.Scheme != "s3a" {
log.Error().Msgf("invalid s3 path: %s", path)
err = ErrInvalidS3Path
return
}
bucket = u.Host
key = strings.TrimPrefix(u.Path, "/")
Expand All @@ -59,8 +62,8 @@ func head(ctx context.Context, bucket, key string, client s3iface.S3API) (*s3.He
})
}

func stat(ctx context.Context, uri string, client s3iface.S3API) (*fileInfo, error) {
bucket, key, err := parsePath(uri)
func Stat(ctx context.Context, uri string, client s3iface.S3API) (*fileInfo, error) {
bucket, key, err := ParsePath(uri)
if err != nil {
return nil, err
}
Expand All @@ -77,11 +80,11 @@ func stat(ctx context.Context, uri string, client s3iface.S3API) (*fileInfo, err
}

func NewS3Reader(ctx context.Context, filepath string, client s3iface.S3API) (*S3Reader, error) {
info, err := stat(ctx, filepath, client)
info, err := Stat(ctx, filepath, client)
if err != nil {
return nil, err
}
bucket, key, err := parsePath(filepath)
bucket, key, err := ParsePath(filepath)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 38671d8

Please sign in to comment.