diff --git a/goastgen/libgoastgen.go b/goastgen/libgoastgen.go index a382374..1ba1839 100644 --- a/goastgen/libgoastgen.go +++ b/goastgen/libgoastgen.go @@ -343,45 +343,52 @@ func processStruct(node interface{}, objPtrValue reflect.Value, fset *token.File // We will iterate through each field process each field according to its reflect.Kind type. for i := 0; i < elementType.NumField(); i++ { field := elementType.Field(i) - value := elementValueObj.Field(i) - fieldKind := value.Type().Kind() - - // If object is defined with field type as interface{} and assigned with pointer value. - // We need to first fetch the element from the interface - if fieldKind == reflect.Interface { - fieldKind = value.Elem().Kind() - value = value.Elem() - } + if field.Name != "Obj" { + // It looks like "Obj" field refers to the original method node from inside call expression node. + // In the event call expression node gets encountered earlier than the FuncDecl node. + // We process this object inside Call Expression and refer to this node inside FuncDecl node. + // However, while processing the AST, we need to process the Function Declaration with FuncDecl Node. + // As we don't use the "Obj" for any reference + value := elementValueObj.Field(i) + fieldKind := value.Type().Kind() + + // If object is defined with field type as interface{} and assigned with pointer value. + // We need to first fetch the element from the interface + if fieldKind == reflect.Interface { + fieldKind = value.Elem().Kind() + value = value.Elem() + } - var ptrValue reflect.Value + var ptrValue reflect.Value - if fieldKind == reflect.Pointer { - // NOTE: This handles only one level of pointer. At this moment we don't expect to get pointer to pointer. - // This will fetch the reflect.Kind of object pointed to by this field pointer - fieldKind = value.Type().Elem().Kind() - // This will fetch the reflect.Value of object pointed to by this field pointer. - ptrValue = value - // capturing the reflect.Value of the pointer if it's a pointer to be passed to recursive processStruct method. - value = value.Elem() - } - // In case the node is pointer, it will check if given Value contains valid pointer address. - if value.IsValid() { - switch fieldKind { - case reflect.String, reflect.Int, reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if value.Type().String() == "token.Token" { - objectMap[field.Name] = value.Interface().(token.Token).String() - } else { - objectMap[field.Name] = value.Interface() + if fieldKind == reflect.Pointer { + // NOTE: This handles only one level of pointer. At this moment we don't expect to get pointer to pointer. + // This will fetch the reflect.Kind of object pointed to by this field pointer + fieldKind = value.Type().Elem().Kind() + // This will fetch the reflect.Value of object pointed to by this field pointer. + ptrValue = value + // capturing the reflect.Value of the pointer if it's a pointer to be passed to recursive processStruct method. + value = value.Elem() + } + // In case the node is pointer, it will check if given Value contains valid pointer address. + if value.IsValid() { + switch fieldKind { + case reflect.String, reflect.Int, reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value.Type().String() == "token.Token" { + objectMap[field.Name] = value.Interface().(token.Token).String() + } else { + objectMap[field.Name] = value.Interface() + } + case reflect.Struct: + objectMap[field.Name] = processStruct(value.Interface(), ptrValue, fset, lastNodeId, nodeAddressMap) + case reflect.Map: + objectMap[field.Name] = processMap(value.Interface(), fset, lastNodeId, nodeAddressMap) + case reflect.Array, reflect.Slice: + objectMap[field.Name] = processArrayOrSlice(value.Interface(), fset, lastNodeId, nodeAddressMap) + default: + log.SetPrefix("[WARNING]") + log.Println(getLogPrefix(), field.Name, "- of Kind ->", fieldKind, "- not handled") } - case reflect.Struct: - objectMap[field.Name] = processStruct(value.Interface(), ptrValue, fset, lastNodeId, nodeAddressMap) - case reflect.Map: - objectMap[field.Name] = processMap(value.Interface(), fset, lastNodeId, nodeAddressMap) - case reflect.Array, reflect.Slice: - objectMap[field.Name] = processArrayOrSlice(value.Interface(), fset, lastNodeId, nodeAddressMap) - default: - log.SetPrefix("[WARNING]") - log.Println(getLogPrefix(), field.Name, "- of Kind ->", fieldKind, "- not handled") } } } diff --git a/main.go b/main.go index f5f59b4..0651f54 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "privado.ai/goastgen/goastgen" + "runtime" "strings" ) @@ -17,6 +18,38 @@ func main() { processRequest(out, inputPath) } +func processFile(out string, inputPath string, path string, info os.FileInfo, resultErr chan error, sem chan int) { + sem <- 1 + defer func() { + <-sem + }() + var outFile = "" + var jsonResult string + var err error + directory := filepath.Dir(path) + if out == ".ast" { + outFile = filepath.Join(inputPath, out, strings.ReplaceAll(directory, inputPath, ""), info.Name()+".json") + } else { + outFile = filepath.Join(out, strings.ReplaceAll(directory, inputPath, ""), info.Name()+".json") + } + if strings.HasSuffix(info.Name(), ".go") { + jsonResult, err = goastgen.ParseAstFromFile(path) + } else if strings.HasSuffix(info.Name(), ".mod") { + jsonResult, err = goastgen.ParseModFromFile(path) + } + if err != nil { + fmt.Printf("Failed to generate AST for %s \n", path) + } else { + err = writeFileContents(outFile, jsonResult) + if err != nil { + fmt.Printf("Error writing AST to output location '%s'\n", outFile) + } else { + fmt.Printf("Converted AST for %s to %s \n", path, outFile) + } + } + resultErr <- err +} + func processRequest(out string, inputPath string) { if strings.HasSuffix(inputPath, ".go") { fileInfo, err := os.Stat(inputPath) @@ -47,6 +80,12 @@ func processRequest(out string, inputPath string) { return } } else { + concurrency := runtime.NumCPU() + var successCount int = 0 + var failCount int = 0 + resultErrChan := make(chan error) + sem := make(chan int, concurrency) + var totalSentForProcessing = 0 err := filepath.Walk(inputPath, func(path string, info os.FileInfo, err error) error { if err != nil { log.SetPrefix("[ERROR]") @@ -54,50 +93,24 @@ func processRequest(out string, inputPath string) { fmt.Printf("Error accessing path '%s'\n", path) return err } - if !info.IsDir() && strings.HasSuffix(info.Name(), ".go") { - var outFile = "" - directory := filepath.Dir(path) - if out == ".ast" { - outFile = filepath.Join(inputPath, out, strings.ReplaceAll(directory, inputPath, ""), info.Name()+".json") - } else { - outFile = filepath.Join(out, strings.ReplaceAll(directory, inputPath, ""), info.Name()+".json") - } - jsonResult, perr := goastgen.ParseAstFromFile(path) - if perr != nil { - fmt.Printf("Failed to generate AST for %s \n", path) - } else { - err = writeFileContents(outFile, jsonResult) - if err != nil { - fmt.Printf("Error writing AST to output location '%s'\n", outFile) - } else { - fmt.Printf("Converted AST for %s to %s \n", path, outFile) - } - return nil - } - } else if strings.HasSuffix(info.Name(), ".mod") { - var outFile = "" - directory := filepath.Dir(path) - if out == ".ast" { - outFile = filepath.Join(inputPath, out, strings.ReplaceAll(directory, inputPath, ""), info.Name()+".json") - } else { - outFile = filepath.Join(out, strings.ReplaceAll(directory, inputPath, ""), info.Name()+".json") - } - jsonResult, perr := goastgen.ParseModFromFile(path) - if perr != nil { - fmt.Printf("Failed to generate AST for %s \n", path) - } else { - err = writeFileContents(outFile, jsonResult) - if err != nil { - fmt.Printf("Error writing AST to output location '%s'\n", outFile) - } else { - fmt.Printf("Converted AST for %s to %s \n", path, outFile) - } - return nil - } + if !info.IsDir() && (strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), ".mod")) { + totalSentForProcessing++ + go processFile(out, inputPath, path, info, resultErrChan, sem) } return nil }) + for i := 0; i < totalSentForProcessing; i++ { + err = <-resultErrChan + if err != nil { + failCount++ + } else { + successCount++ + } + } + println("\n\n\n\n Without error -> ", successCount, ", With Error -> ", failCount) + println("total files sent for processing ----> ", totalSentForProcessing) + println("No of CPUs --->", concurrency) if err != nil { log.SetPrefix("[ERROR]") log.Printf("Error walking the path %s: %v\n", inputPath, err)