diff --git a/.gitignore b/.gitignore index 4818cf2..812d85e 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ cache/ server/ tmp/ main + +.DS_Store diff --git a/cmd/start/start.go b/cmd/start/start.go index bb890ff..028b6ee 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -7,7 +7,6 @@ import ( "io" "os" "runtime" - "strconv" pty "github.com/MCSManager/pty/console" "github.com/MCSManager/pty/utils" @@ -16,9 +15,8 @@ import ( ) var ( - dir, cmd, coder, ptySize, pid, mode string - cmds []string - colorAble, exhaustive, skipExistFile bool + dir, cmd, coder, ptySize string + cmds []string ) type PtyInfo struct { @@ -32,56 +30,24 @@ func init() { flag.StringVar(&cmd, "cmd", "[\"sh\"]", "command") } - flag.BoolVar(&colorAble, "color", true, "colorable (default true)") - flag.BoolVar(&skipExistFile, "s", false, "Skip Exist File (default false)") - flag.BoolVar(&exhaustive, "e", false, "Zip Exhaustive (default false)") flag.StringVar(&coder, "coder", "auto", "Coder") - flag.StringVar(&pid, "pid", "0", "detect pid info") flag.StringVar(&dir, "dir", ".", "command work path") flag.StringVar(&ptySize, "size", "80,50", "Initialize pty size, stdin will be forwarded directly") - flag.StringVar(&mode, "m", "pty", "set mode") } func Main() { flag.Parse() - args := flag.Args() - switch mode { - case "zip": - if err := utils.Zip(args[:len(args)-1], args[len(args)-1], utils.ZipCfg{Exhaustive: exhaustive}); err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - case "unzip": - if err := utils.Unzip(args[0], args[1], utils.UnzipCfg{CoderTypes: utils.CoderToType(coder), SkipExistFile: skipExistFile, Exhaustive: exhaustive}); err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - case "info": - runtime.GOMAXPROCS(2) - info := utils.NewInfo() - upid, err := strconv.ParseInt(pid, 10, 32) - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - utils.Detect(int32(upid), info) - pinfo, err := json.Marshal(info) - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - fmt.Println(string(pinfo)) - default: - runtime.GOMAXPROCS(6) - runPTY() - } + runPTY() } func runPTY() { - json.Unmarshal([]byte(cmd), &cmds) + if err := json.Unmarshal([]byte(cmd), &cmds); err != nil { + fmt.Println("[MCSMANAGER-PTY] Unmarshal command error: ", err) + return + } con := pty.New(utils.CoderToType(coder)) if err := con.ResizeWithString(ptySize); err != nil { - fmt.Printf("[MCSMANAGER-PTY] PTY ReSize Error: %v\n", err) + fmt.Printf("[MCSMANAGER-PTY] PTY Resize error: %v\n", err) return } err := con.Start(dir, cmds) @@ -90,12 +56,12 @@ func runPTY() { }) fmt.Println(string(info)) if err != nil { - fmt.Printf("[MCSMANAGER-PTY] Process Start Error: %v\n", err) + fmt.Printf("[MCSMANAGER-PTY] Process start error: %v\n", err) return } defer con.Close() handleStdIO(con) - con.Wait() + _, _ = con.Wait() } func handleStdIO(c pty.Console) { @@ -107,26 +73,14 @@ func handleStdIO(c pty.Console) { defer func() { _ = term.Restore(int(os.Stdin.Fd()), oldState) }() go func() { _, _ = io.Copy(c.StdIn(), os.Stdin) }() } else { - go io.Copy(c.StdIn(), os.Stdin) + go func() { _, _ = io.Copy(c.StdIn(), os.Stdin) }() } if runtime.GOOS == "windows" && c.StdErr() != nil { - var stdErr io.Reader - if colorAble { - stdErr = c.StdErr() - } else { - stdErr = colorable.NewNonColorableReader(c.StdErr()) - } - go io.Copy(colorable.NewColorableStderr(), stdErr) + go func() { _, _ = io.Copy(colorable.NewColorableStderr(), c.StdErr()) }() } handleStdOut(c) } func handleStdOut(c pty.Console) { - var stdOut io.Reader - if colorAble { - stdOut = c.StdOut() - } else { - stdOut = colorable.NewNonColorableReader(c.StdOut()) - } - io.Copy(colorable.NewColorableStdout(), stdOut) + _, _ = io.Copy(colorable.NewColorableStdout(), c.StdOut()) } diff --git a/console/console_windows.go b/console/console_windows.go index a850afe..a17618f 100644 --- a/console/console_windows.go +++ b/console/console_windows.go @@ -1,13 +1,14 @@ package console import ( - "bytes" + "embed" _ "embed" "fmt" "io" "os" "os/exec" "path/filepath" + "strings" "time" "github.com/MCSManager/pty/console/go-winpty" @@ -16,8 +17,8 @@ import ( mutex "github.com/juju/mutex/v2" ) -//go:embed winpty -var winpty_zip []byte +//go:embed all:winpty +var winpty_embed embed.FS var _ iface.Console = (*console)(nil) @@ -44,7 +45,12 @@ func (c *console) Start(dir string, command []string) error { defer r.Release() if dir, err = filepath.Abs(dir); err != nil { return err - } else if err := os.Chdir(dir); err != nil { + } + if err := os.Chdir(dir); err != nil { + return err + } + dllDir, err := c.findDll() + if err != nil { return err } cmd, err := c.buildCmd(command) @@ -52,7 +58,7 @@ func (c *console) Start(dir string, command []string) error { return err } option := winpty.Options{ - DllDir: filepath.Join(os.TempDir(), "pty_winpty"), + DllDir: dllDir, Command: cmd, Dir: dir, Env: c.env, @@ -65,12 +71,7 @@ func (c *console) Start(dir string, command []string) error { var pty *winpty.WinPTY if pty, err = winpty.OpenWithOptions(option); err != nil { - if option.DllDir, err = c.findDll(); err != nil { - return err - } - if pty, err = winpty.OpenWithOptions(option); err != nil { - return err - } + return err } c.stdIn = pty.Stdin c.stdOut = pty.Stdout @@ -79,16 +80,16 @@ func (c *console) Start(dir string, command []string) error { return nil } -// splice command func (c *console) buildCmd(args []string) (string, error) { if len(args) == 0 { return "", ErrInvalidCmd } - var cmds = fmt.Sprintf("cmd /C chcp %s > nul & ", utils.CodePage(c.coder)) - for _, v := range args { - cmds += v + ` ` - } - return cmds[:len(cmds)-1], nil + var cmds = fmt.Sprintf( + "cmd /C chcp %s > nul & %s", + utils.CodePage(c.coder), + strings.Join(args, " "), + ) + return cmds, nil } type fakeClock struct { @@ -110,9 +111,39 @@ func (c *console) findDll() (string, error) { return "", err } - return dllDir, utils.UnzipWithFile(bytes.NewReader(winpty_zip), dllDir, utils.UnzipCfg{ - CoderTypes: utils.T_UTF8, - }) + dir, err := winpty_embed.ReadDir("winpty") + if err != nil { + return "", fmt.Errorf("read embed dir error: %w", err) + } + + for _, de := range dir { + info, err := de.Info() + if err != nil { + return "", err + } + var exist bool + df, err := os.Stat(filepath.Join(dllDir, de.Name())) + if err != nil { + if !os.IsNotExist(err) { + return "", err + } + } else { + if !df.ModTime().Before(info.ModTime()) { + exist = true + } + } + if !exist { + data, err := winpty_embed.ReadFile(fmt.Sprintf("winpty/%s", de.Name())) + if err != nil { + return "", fmt.Errorf("read embed file error: %w", err) + } + if err := os.WriteFile(filepath.Join(dllDir, de.Name()), data, os.ModePerm); err != nil { + return "", fmt.Errorf("write file error: %w", err) + } + } + } + + return dllDir, nil } // set pty window size diff --git a/console/winpty b/console/winpty deleted file mode 100644 index 0f7916a..0000000 Binary files a/console/winpty and /dev/null differ diff --git a/console/winpty/winpty-agent.exe b/console/winpty/winpty-agent.exe new file mode 100644 index 0000000..e59d64d Binary files /dev/null and b/console/winpty/winpty-agent.exe differ diff --git a/console/winpty/winpty.dll b/console/winpty/winpty.dll new file mode 100644 index 0000000..e192a49 Binary files /dev/null and b/console/winpty/winpty.dll differ diff --git a/go.mod b/go.mod index 14c1dc2..a768001 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module github.com/MCSManager/pty go 1.18 require ( - github.com/creack/pty v1.1.18 + github.com/creack/pty v1.1.21 github.com/juju/mutex/v2 v2.0.0 github.com/klauspost/compress v1.16.5 github.com/mholt/archiver/v4 v4.0.0-alpha.8 github.com/shirou/gopsutil/v3 v3.23.4 - github.com/zijiren233/go-colorable v0.0.0-20230522040028-05f4e204585c - golang.org/x/term v0.8.0 - golang.org/x/text v0.9.0 + github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb + golang.org/x/term v0.18.0 + golang.org/x/text v0.14.0 ) require ( @@ -27,7 +27,7 @@ require ( github.com/juju/errors v1.0.0 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/lufia/plan9stats v0.0.0-20230326075908-cb1d2100619a // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect github.com/pierrec/lz4/v4 v4.1.17 // indirect github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b // indirect @@ -38,5 +38,5 @@ require ( github.com/ulikunitz/xz v0.5.11 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect - golang.org/x/sys v0.8.0 // indirect + golang.org/x/sys v0.18.0 // indirect ) diff --git a/go.sum b/go.sum index c1759cf..7877768 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/connesc/cipherio v0.2.1 h1:FGtpTPMbKNNWByNrr9aEBtaJtXjqOzkIXNYJp6OEyc github.com/connesc/cipherio v0.2.1/go.mod h1:ukY0MWJDFnJEbXMQtOcn2VmTpRfzcTz4OoVrWGGJZcA= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= +github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -127,6 +129,8 @@ github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPn github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mholt/archiver/v4 v4.0.0-alpha.7.0.20230223155640-de8cf229f727 h1:9ScivJvfYiMaoHng1p5wvhM4iAKzXUVVKBIo+QotEAI= github.com/mholt/archiver/v4 v4.0.0-alpha.7.0.20230223155640-de8cf229f727/go.mod h1:5f7FUYGXdJWUjESffJaYR4R60VhnHxb2X3T1teMyv5A= github.com/mholt/archiver/v4 v4.0.0-alpha.8 h1:tRGQuDVPh66WCOelqe6LIGh0gwmfwxUrSSDunscGsRM= @@ -181,6 +185,8 @@ github.com/zijiren233/go-colorable v0.0.0-20230304035935-641eddfc7ecf h1:EwTsEmz github.com/zijiren233/go-colorable v0.0.0-20230304035935-641eddfc7ecf/go.mod h1:TJFyVPDSW/YIGewz0BLYt4sACOV0xXJbNbfaRMJPXHc= github.com/zijiren233/go-colorable v0.0.0-20230522040028-05f4e204585c h1:5Q2rC1jyLWlhjbGii+94yK7xefMstT2WXJnsR5G7qDA= github.com/zijiren233/go-colorable v0.0.0-20230522040028-05f4e204585c/go.mod h1:TJFyVPDSW/YIGewz0BLYt4sACOV0xXJbNbfaRMJPXHc= +github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb h1:0DyOxf/TbbGodHhOVHNoPk+7v/YBJACs22gKpKlatWw= +github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb/go.mod h1:6TCzjDiQ8+5gWZiwsC3pnA5M0vUy2jV2Y7ciHJh729g= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -275,12 +281,16 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -291,6 +301,8 @@ golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/utils/coder.go b/utils/coder.go index ab8cae7..45ed499 100644 --- a/utils/coder.go +++ b/utils/coder.go @@ -130,65 +130,3 @@ func newEeCoder(coder CoderType) *encoding.Encoder { } return decoder } - -// 先判断是否是UTF8再判断是否是其它编码才有意义 -func isUtf8(data []byte) (bool, CoderType) { - i := 0 - for i < len(data) { - if (data[i] & 0x80) == 0x00 { - i++ - continue - } else if num := preNUm(data[i]); num > 2 { - i++ - for j := 0; j < num-1; j++ { - //判断后面的 num - 1 个字节是不是都是10开头 - if (data[i] & 0xc0) != 0x80 { - return false, T_UTF8 - } - i++ - } - } else { - //其他情况说明不是utf-8 - return false, T_UTF8 - } - } - return true, T_UTF8 -} - -func isGBK(data []byte) (bool, CoderType) { - length := len(data) - var i int = 0 - for i < length { - if data[i] <= 0x7f { - //编码0~127,只有一个字节的编码,兼容ASCII码 - i++ - continue - } else if i+1 < length { - //大于127的使用双字节编码,落在gbk编码范围内的字符 - if data[i] >= 0x81 && - data[i] <= 0xfe && - data[i+1] >= 0x40 && - data[i+1] <= 0xfe && - data[i+1] != 0x7f { - i += 2 - continue - } - } - return false, T_GBK - } - return true, T_GBK -} - -func preNUm(data byte) int { - var mask byte = 0x80 - var num int = 0 - for i := 0; i < 8; i++ { - if (data & mask) == mask { - num++ - mask = mask >> 1 - } else { - break - } - } - return num -} diff --git a/utils/detect.go b/utils/detect.go deleted file mode 100644 index 98b122d..0000000 --- a/utils/detect.go +++ /dev/null @@ -1,52 +0,0 @@ -package utils - -import ( - "sync" - "sync/atomic" - "time" - - "github.com/shirou/gopsutil/v3/process" -) - -type Info struct { - Mem uint64 `json:"mem"` - Cpu float64 `json:"cpu"` - NumConn int32 `json:"numConn"` - IOReadSpeed uint64 `json:"ioReadSpeed"` - IOWriteSpeed uint64 `json:"ioWriteSpeed"` - lock *sync.Mutex -} - -func NewInfo() *Info { - return &Info{lock: &sync.Mutex{}} -} - -func Detect(pid int32, info *Info) { - p, err := process.NewProcess(pid) - if err != nil { - return - } - if children, err := p.Children(); err == nil { - for _, v := range children { - go Detect(v.Pid, info) - } - } - if conn, err := p.Connections(); err == nil { - atomic.AddInt32(&info.NumConn, int32(len(conn))) - } - if io1, err := p.IOCounters(); err == nil { - time.Sleep(time.Millisecond * 250) - if io2, err := p.IOCounters(); err == nil { - atomic.AddUint64(&info.IOReadSpeed, (io2.ReadBytes-io1.ReadBytes)*4) - atomic.AddUint64(&info.IOWriteSpeed, (io2.WriteBytes-io1.WriteBytes)*4) - } - } - if mem, err := p.MemoryInfo(); err == nil { - atomic.AddUint64(&info.Mem, mem.RSS) - } - if cpu, err := p.CPUPercent(); err == nil { - info.lock.Lock() - info.Cpu += cpu - info.lock.Unlock() - } -} diff --git a/utils/unzip.go b/utils/unzip.go deleted file mode 100644 index 938c765..0000000 --- a/utils/unzip.go +++ /dev/null @@ -1,172 +0,0 @@ -package utils - -import ( - "bufio" - "context" - "errors" - "fmt" - "io" - "os" - "path/filepath" - - archiver "github.com/mholt/archiver/v4" - - "golang.org/x/text/encoding" - "golang.org/x/text/transform" -) - -const bufSize = 512 * 1024 - -type UnzipCfg struct { - BufferSize int - Ctx context.Context - CoderTypes CoderType - SkipExistFile, Exhaustive bool -} - -func Unzip(zipPath string, TargetPath string, cfg UnzipCfg) (err error) { - if zipPath, err = filepath.Abs(zipPath); err != nil { - return - } - zipFile, err := os.Open(zipPath) - if err != nil { - return - } - defer zipFile.Close() - return UnzipWithFile(zipFile, TargetPath, cfg) -} - -func UnzipWithFile(zipFile io.Reader, TargetPath string, cfg UnzipCfg) error { - _initZipCompressor() - if cfg.Ctx == nil { - cfg.Ctx = context.Background() - } - if cfg.BufferSize == 0 { - cfg.BufferSize = bufSize - } - seek, ok := zipFile.(io.Seeker) - if !ok { - return errors.New("seek file error") - } - var err error - if TargetPath, err = filepath.Abs(TargetPath); err != nil { - return err - } - err = os.MkdirAll(TargetPath, os.ModePerm) - if err != nil { - return err - } - format, _, err := archiver.Identify("", zipFile) - if err != nil { - return err - } - if cfg.CoderTypes == T_Auto { - m := zipEncode(cfg.Ctx, format, zipFile, isUtf8, isGBK) - _, err = seek.Seek(0, io.SeekStart) - if err != nil { - return err - } - if m[T_GBK] && !m[T_UTF8] { - cfg.CoderTypes = T_GBK - err = decode(format, zipFile, TargetPath, cfg) - } else { - err = decode(format, zipFile, TargetPath, cfg) - } - } else { - err = decode(format, zipFile, TargetPath, cfg) - } - return err -} - -func zipEncode(ctx context.Context, format archiver.Format, r io.Reader, fun ...func(data []byte) (bool, CoderType)) (res map[CoderType]bool) { - res = make(map[CoderType]bool) - if ex, ok := format.(archiver.Extractor); ok { - ex.Extract(ctx, r, nil, func(ctx context.Context, f archiver.File) error { - for _, fn := range fun { - select { - case <-ctx.Done(): - return ctx.Err() - default: - ok, name := fn([]byte(f.Name())) - if b, o := res[name]; o { - if !b { - continue - } else { - res[name] = ok - } - } else { - res[name] = ok - } - } - } - return nil - }) - } - return -} - -func decode(format archiver.Format, r io.Reader, TargetPath string, cfg UnzipCfg) error { - var decoder *encoding.Decoder - if cfg.CoderTypes != T_Auto { - decoder = newDeCoder(cfg.CoderTypes) - } - if ex, ok := format.(archiver.Extractor); ok { - buffer := make([]byte, cfg.BufferSize) - return ex.Extract(cfg.Ctx, r, nil, func(ctx context.Context, f archiver.File) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - var result string - if decoder == nil { - result = f.NameInArchive - } else { - var err error - result, _, err = transform.String(decoder, f.NameInArchive) - if err != nil { - fmt.Printf("File %s err: %v", f.NameInArchive, err) - return err - } - } - if cfg.Exhaustive { - fmt.Println(result) - } - fpath := filepath.Join(TargetPath, result) - if f.IsDir() { - return os.MkdirAll(fpath, f.Mode()) - } - if cfg.SkipExistFile { - _, err := os.Stat(fpath) - if err == nil { - return err - } - } - inFile, err := f.Open() - if err != nil { - return err - } - defer inFile.Close() - - if err := os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil { - return err - } - file, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - if err != nil { - return err - } - defer file.Close() - var outFile io.Writer - if f.Size() > bufSize { - buf := bufio.NewWriterSize(file, 4*bufSize) - outFile = buf - defer buf.Flush() - } else { - outFile = file - } - _, err = io.CopyBuffer(outFile, inFile, buffer) - return err - } - }) - } - return errors.New("format.(archiver.Extractor) err") -} diff --git a/utils/zip.go b/utils/zip.go deleted file mode 100644 index dde5148..0000000 --- a/utils/zip.go +++ /dev/null @@ -1,133 +0,0 @@ -package utils - -import ( - "bufio" - "compress/flate" - "context" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strings" - "sync" - - "github.com/klauspost/compress/zip" - - archiver "github.com/mholt/archiver/v4" -) - -var initZipCompressor = sync.Once{} - -func _initZipCompressor() { - initZipCompressor.Do(func() { - zip.RegisterCompressor(flate.BestCompression, func(w io.Writer) (io.WriteCloser, error) { - return flate.NewWriter(w, flate.BestCompression) - }) - zip.RegisterCompressor(flate.BestSpeed, func(w io.Writer) (io.WriteCloser, error) { - return flate.NewWriter(w, flate.BestSpeed) - }) - zip.RegisterDecompressor(flate.BestCompression, flate.NewReader) - zip.RegisterDecompressor(flate.BestSpeed, flate.NewReader) - }) -} - -type ZipCfg struct { - BufferSize int - Ctx context.Context - Exhaustive bool -} - -func Zip(FilePath []string, ZipPath string, cfg ZipCfg) error { - _initZipCompressor() - if cfg.Ctx == nil { - cfg.Ctx = context.Background() - } - if cfg.BufferSize == 0 { - cfg.BufferSize = bufSize - } - if len(FilePath) == 0 { - return errors.New("file is nil") - } - var err error - FilePath[0], err = filepath.Abs(FilePath[0]) - if err != nil { - return err - } - var baseDir = filepath.Dir(FilePath[0]) - if len(FilePath) == 1 { - fi, err := os.Stat(FilePath[0]) - if err != nil { - return err - } - if fi.IsDir() { - baseDir = FilePath[0] - } - } - for k, v := range FilePath[1:] { - FilePath[k+1], err = filepath.Abs(v) - if err != nil { - return err - } - if filepath.Dir(FilePath[k+1]) != baseDir { - return errors.New("base dir err") - } - } - ZipPath, err = filepath.Abs(ZipPath) - if err != nil { - return err - } - zipExi := strings.ToLower(filepath.Ext(ZipPath)) - var format archiver.CompressedArchive - switch zipExi { - case "": - ZipPath += ".zip" - format = archiver.CompressedArchive{ - Archival: archiver.Zip{Compression: zip.Deflate, SelectiveCompression: true}, - } - case ".tar": - format = archiver.CompressedArchive{ - Archival: archiver.Tar{}, - } - case ".gz", ".tgz": - format = archiver.CompressedArchive{ - Compression: archiver.Gz{CompressionLevel: flate.DefaultCompression, Multithreaded: true}, - Archival: archiver.Tar{}, - } - case ".zip": - format = archiver.CompressedArchive{ - Archival: archiver.Zip{Compression: zip.Deflate, SelectiveCompression: true}, - } - default: - return errors.New("not support this exi") - } - fileMap := make(map[string]string) - for _, fPath := range FilePath { - select { - case <-cfg.Ctx.Done(): - return cfg.Ctx.Err() - default: - if cfg.Exhaustive { - fmt.Println(fPath) - } - fileMap[fPath] = strings.TrimPrefix(strings.TrimPrefix(fPath, baseDir), string(os.PathSeparator)) - } - } - files, err := archiver.FilesFromDisk(nil, fileMap) - if err != nil { - return err - } - err = os.MkdirAll(filepath.Dir(ZipPath), os.ModePerm) - if err != nil { - return err - } - zipfile, err := os.Create(ZipPath) - if err != nil { - return err - } - defer zipfile.Close() - fmt.Println("Archiving, please wait...") - buf := bufio.NewWriterSize(zipfile, cfg.BufferSize) - defer buf.Flush() - return format.Archive(cfg.Ctx, buf, files) -}