diff --git a/main.go b/main.go index caaf8feb..0689b7cd 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "os/signal" + "regexp" "strconv" "strings" "sync" @@ -100,24 +101,13 @@ func runCmd() error { return errors.Errorf("tab width must be > 0: %d", *flagTabWidth) } - var sl tree.StatementList - if len(*flagStmts) != 0 { - for _, exec := range *flagStmts { - stmts, err := parser.Parse(exec) - if err != nil { - return err - } - sl = append(sl, stmts...) - } - } else { + sl := *flagStmts + if len(sl) == 0 { in, err := ioutil.ReadAll(os.Stdin) if err != nil { return err } - sl, err = parser.Parse(string(in)) - if err != nil { - return err - } + sl = append(sl, string(in)) } cfg := tree.DefaultPrettyCfg() @@ -130,16 +120,68 @@ func runCmd() error { cfg.Align = tree.PrettyAlignAndDeindent } - for _, s := range sl { - fmt.Print(cfg.Pretty(s)) - if len(sl) > 1 { - fmt.Print(";") - } - fmt.Println() + res, err := fmtsql(cfg, sl) + if err != nil { + return err } + fmt.Println(res) return nil } +var ( + ignoreComments = regexp.MustCompile(`^--.*\s*`) +) + +func fmtsql(cfg tree.PrettyCfg, stmts []string) (string, error) { + var prettied strings.Builder + for _, stmt := range stmts { + for len(stmt) > 0 { + stmt = strings.TrimSpace(stmt) + hasContent := false + // Trim comments, preserving whitespace after them. + for { + found := ignoreComments.FindString(stmt) + if found == "" { + break + } + // Remove trailing whitespace but keep up to 2 newlines. + prettied.WriteString(strings.TrimSpace(found)) + newlines := strings.Count(found, "\n") + if newlines > 2 { + newlines = 2 + } + prettied.WriteString(strings.Repeat("\n", newlines)) + stmt = stmt[len(found):] + hasContent = true + } + // Split by semicolons + next := stmt + scan := parser.MakeScanner(stmt) + if pos := scan.Until(';'); pos > 0 { + next = stmt[:pos] + stmt = stmt[pos:] + } else { + stmt = "" + } + // This should only return 0 or 1 responses. + allParsed, err := parser.Parse(next) + if err != nil { + return "", err + } + for _, parsed := range allParsed { + prettied.WriteString(cfg.Pretty(parsed)) + prettied.WriteString(";\n") + hasContent = true + } + if hasContent { + prettied.WriteString("\n") + } + } + } + + return strings.TrimSpace(prettied.String()), nil +} + func serveHTTP(spec Specification) { fmt.Printf("SPEC: %#v\n", spec) base := template.Must(template.New("base").Parse(Base)) @@ -234,12 +276,12 @@ func (db dbCache) Delete(ctx context.Context, key string) error { return err } -func wrap(f func(http.ResponseWriter, *http.Request) []string) http.HandlerFunc { +func wrap(f func(http.ResponseWriter, *http.Request) fmtResponse) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { res := f(w, r) if r.FormValue("json") == "" { w.Header().Add("Content-Type", "text/plain") - fmt.Fprintln(w, strings.Join(res, ";\n")) + w.Write([]byte(res.Data)) } else { w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(res); err != nil { @@ -249,11 +291,16 @@ func wrap(f func(http.ResponseWriter, *http.Request) []string) http.HandlerFunc } } +type fmtResponse struct { + Data string + Error bool +} + var cache = struct { sync.RWMutex - m map[string][]string + m map[string]fmtResponse }{ - m: make(map[string][]string), + m: make(map[string]fmtResponse), } func parseBool(val string) (bool, error) { @@ -267,7 +314,7 @@ func parseBool(val string) (bool, error) { } } -func Fmt(w http.ResponseWriter, r *http.Request) []string { +func Fmt(w http.ResponseWriter, r *http.Request) fmtResponse { cache.RLock() hit, ok := cache.m[r.URL.RawQuery] cache.RUnlock() @@ -275,19 +322,26 @@ func Fmt(w http.ResponseWriter, r *http.Request) []string { return hit } - res := fmtsql(r) + res, err := fmtSQLRequest(r) + response := fmtResponse{ + Data: res, + Error: err != nil, + } + if err != nil { + response.Data = err.Error() + } cache.Lock() if len(cache.m) > 10000 { for k := range cache.m { delete(cache.m, k) } } - cache.m[r.URL.RawQuery] = res + cache.m[r.URL.RawQuery] = response cache.Unlock() - return res + return response } -func fmtsql(r *http.Request) []string { +func fmtSQLRequest(r *http.Request) (string, error) { sql := r.FormValue("sql") trimmed := strings.Join(strings.Fields(sql), " ") if len(trimmed) > 100 { @@ -296,28 +350,24 @@ func fmtsql(r *http.Request) []string { n, err := strconv.Atoi(r.FormValue("n")) if err != nil { - return []string{"error", err.Error()} + return "", err } log.Printf("fmt (sqln: %d, n: %d): %s", len(sql), n, trimmed) tabWidth, err := strconv.Atoi(r.FormValue("indent")) if err != nil { - return []string{"error", err.Error()} + return "", err } simplify, err := parseBool(r.FormValue("simplify")) if err != nil { - return []string{"error", err.Error()} + return "", err } align, err := strconv.Atoi(r.FormValue("align")) if err != nil { - return []string{"error", err.Error()} + return "", err } spaces, err := parseBool(r.FormValue("spaces")) if err != nil { - return []string{"error", err.Error()} - } - sl, err := parser.Parse(sql) - if err != nil { - return []string{"error", err.Error()} + return "", err } pcfg := tree.DefaultPrettyCfg() @@ -327,11 +377,7 @@ func fmtsql(r *http.Request) []string { pcfg.Simplify = simplify pcfg.Align = tree.PrettyAlignMode(align) - res := make([]string, len(sl)) - for i, s := range sl { - res[i] = pcfg.Pretty(s) - } - return res + return fmtsql(pcfg, []string{sql}) } const ( @@ -590,12 +636,12 @@ function range() { resp => { working = false; resp.json().then(data => { - if (data.length === 2 && data[0].includes('error')) { - fmt.innerText = data[1]; + if (data.Error) { + fmt.innerText = data.Data; actualWidth.innerText = ''; actualBytes.innerText = ''; } else { - fmtText = data.map(d => d + ';').join('\n\n'); + fmtText = data.Data tabSpaces = " ".repeat(viw); actualWidth.innerText = Math.max(...fmtText.split('\n').map(v => v.replace(/\t/g, tabSpaces).length)); actualBytes.innerText = fmtText.length;