Skip to content

Commit

Permalink
gazelle: Populate plugins attributes with annotation processors
Browse files Browse the repository at this point in the history
  • Loading branch information
illicitonion committed May 8, 2024
1 parent f518e2e commit a8e904b
Show file tree
Hide file tree
Showing 13 changed files with 345 additions and 39 deletions.
17 changes: 17 additions & 0 deletions java/gazelle/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/bazel-contrib/rules_jvm/java/gazelle/javaconfig"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/javaparser"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/maven"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
"github.com/bazelbuild/bazel-gazelle/config"
"github.com/bazelbuild/bazel-gazelle/rule"
bzl "github.com/bazelbuild/buildtools/build"
Expand Down Expand Up @@ -64,6 +65,7 @@ func (jc *Configurer) KnownDirectives() []string {
javaconfig.JavaTestMode,
javaconfig.JavaGenerateProto,
javaconfig.JavaMavenRepositoryName,
javaconfig.JavaAnnotationProcessorPlugin,
}
}

Expand Down Expand Up @@ -129,6 +131,21 @@ func (jc *Configurer) Configure(c *config.Config, rel string, f *rule.File) {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %s: possible values are true/false",
javaconfig.JavaGenerateProto, d.Value)
}
case javaconfig.JavaAnnotationProcessorPlugin:
// Format: # gazelle:java_annotation_processor_plugin com.example.AnnotationName com.example.AnnotationProcessorImpl
parts := strings.Split(d.Value, " ")
if len(parts) != 2 {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %s: expected an annotation class-name followed by a processor class-name", javaconfig.JavaAnnotationProcessorPlugin, d.Value)
}
annotationClassName, err := types.ParseClassName(parts[0])
if err != nil {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %q: couldn't parse annotation processor annotation class-name: %v", javaconfig.JavaAnnotationProcessorPlugin, parts[0], err)
}
processorClassName, err := types.ParseClassName(parts[1])
if err != nil {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %q: couldn't parse annotation processor class-name: %v", javaconfig.JavaAnnotationProcessorPlugin, parts[1], err)
}
cfg.AddAnnotationProcessorPlugin(*annotationClassName, *processorClassName)
}
}
}
Expand Down
23 changes: 16 additions & 7 deletions java/gazelle/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,18 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
return true
})

annotationProcessorClasses := sorted_set.NewSortedSetFn(nil, types.ClassNameLess)
for _, annotationClass := range javaPkg.AllAnnotations().SortedSlice() {
annotationProcessorClasses.AddAll(cfg.GetAnnotationProcessorPluginClasses(annotationClass))
}

javaLibraryKind := "java_library"
if kindMap, ok := args.Config.KindMap["java_library"]; ok {
javaLibraryKind = kindMap.KindName
}

if productionJavaFiles.Len() > 0 {
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), productionJavaFiles.SortedSlice(), allPackageNames, nonLocalProductionJavaImports, nonLocalJavaExports, false, javaLibraryKind, &res)
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), productionJavaFiles.SortedSlice(), allPackageNames, nonLocalProductionJavaImports, nonLocalJavaExports, annotationProcessorClasses, false, javaLibraryKind, &res)
}

var testHelperJavaClasses *sorted_set.SortedSet[types.ClassName]
Expand Down Expand Up @@ -228,7 +233,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
testJavaImportsWithHelpers.Add(tf.pkg)
srcs = append(srcs, tf.pathRelativeToBazelWorkspaceRoot)
}
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), srcs, packages, testJavaImports, nonLocalJavaExports, true, javaLibraryKind, &res)
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), srcs, packages, testJavaImports, nonLocalJavaExports, annotationProcessorClasses, true, javaLibraryKind, &res)
}
}

Expand All @@ -240,7 +245,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
case "file":
for _, tf := range testJavaFiles.SortedSlice() {
separateJavaTestReasons := separateTestJavaFiles[tf]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, annotationProcessorClasses, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}

case "suite":
Expand Down Expand Up @@ -268,6 +273,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
packageNames,
cfg.MavenRepositoryName(),
testJavaImportsWithHelpers,
annotationProcessorClasses,
cfg.GetCustomJavaTestFileSuffixes(),
testHelperJavaFiles.Len() > 0,
&res,
Expand All @@ -284,7 +290,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
testHelperDep = ptr(testHelperLibname(suiteName))
}
separateJavaTestReasons := separateTestJavaFiles[src]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, annotationProcessorClasses, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}
}
}
Expand Down Expand Up @@ -453,7 +459,7 @@ func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFil
}
}

func (l javaLang) generateJavaLibrary(file *rule.File, pathToPackageRelativeToBazelWorkspace string, name string, srcsRelativeToBazelWorkspace []string, packages, imports *sorted_set.SortedSet[types.PackageName], exports *sorted_set.SortedSet[types.PackageName], testonly bool, javaLibraryRuleKind string, res *language.GenerateResult) {
func (l javaLang) generateJavaLibrary(file *rule.File, pathToPackageRelativeToBazelWorkspace string, name string, srcsRelativeToBazelWorkspace []string, packages, imports *sorted_set.SortedSet[types.PackageName], exports *sorted_set.SortedSet[types.PackageName], annotationProcessorClasses *sorted_set.SortedSet[types.ClassName], testonly bool, javaLibraryRuleKind string, res *language.GenerateResult) {
const ruleKind = "java_library"
r := rule.NewRule(ruleKind, name)

Expand Down Expand Up @@ -487,6 +493,7 @@ func (l javaLang) generateJavaLibrary(file *rule.File, pathToPackageRelativeToBa
PackageNames: packages,
ImportedPackageNames: imports,
ExportedPackageNames: exports,
AnnotationProcessors: annotationProcessorClasses,
}
res.Imports = append(res.Imports, resolveInput)
}
Expand All @@ -511,7 +518,7 @@ func (l javaLang) generateJavaBinary(file *rule.File, m types.ClassName, libName
})
}

func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], annotationProcessorClasses *sorted_set.SortedSet[types.ClassName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
className := f.ClassName()
fullyQualifiedTestClass := className.FullyQualifiedClassName()
var testName string
Expand Down Expand Up @@ -571,6 +578,7 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
resolveInput := types.ResolveInput{
PackageNames: sorted_set.NewSortedSetFn([]types.PackageName{f.pkg}, types.PackageNameLess),
ImportedPackageNames: testImports,
AnnotationProcessors: annotationProcessorClasses,
}
res.Imports = append(res.Imports, resolveInput)
}
Expand Down Expand Up @@ -598,7 +606,7 @@ var junit5RuntimeDeps = []string{
"org.junit.platform:junit-platform-reporting",
}

func (l javaLang) generateJavaTestSuite(file *rule.File, name string, srcs []string, packageNames *sorted_set.SortedSet[types.PackageName], mavenRepositoryName string, imports *sorted_set.SortedSet[types.PackageName], customTestSuffixes *[]string, hasHelpers bool, res *language.GenerateResult) {
func (l javaLang) generateJavaTestSuite(file *rule.File, name string, srcs []string, packageNames *sorted_set.SortedSet[types.PackageName], mavenRepositoryName string, imports *sorted_set.SortedSet[types.PackageName], annotationProcessorClasses *sorted_set.SortedSet[types.ClassName], customTestSuffixes *[]string, hasHelpers bool, res *language.GenerateResult) {
const ruleKind = "java_test_suite"
r := rule.NewRule(ruleKind, name)
r.SetAttr("srcs", srcs)
Expand Down Expand Up @@ -636,6 +644,7 @@ func (l javaLang) generateJavaTestSuite(file *rule.File, name string, srcs []str
resolveInput := types.ResolveInput{
PackageNames: packageNames,
ImportedPackageNames: suiteImports,
AnnotationProcessors: annotationProcessorClasses,
}
res.Imports = append(res.Imports, resolveInput)
}
Expand Down
4 changes: 2 additions & 2 deletions java/gazelle/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestSingleJavaTestFile(t *testing.T) {
var res language.GenerateResult

l := newTestJavaLang(t)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, tc.wrapper, nil, &res)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, nil, tc.wrapper, nil, &res)

require.Len(t, res.Gen, 1, "want 1 generated rule")

Expand Down Expand Up @@ -252,7 +252,7 @@ func TestSuite(t *testing.T) {
var res language.GenerateResult

l := newTestJavaLang(t)
l.generateJavaTestSuite(nil, "blah", []string{src}, stringsToPackageNames([]string{pkg}), "maven", stringsToPackageNames(tc.importedPackages), nil, false, &res)
l.generateJavaTestSuite(nil, "blah", []string{src}, stringsToPackageNames([]string{pkg}), "maven", stringsToPackageNames(tc.importedPackages), nil, nil, false, &res)

require.Len(t, res.Gen, 1, "want 1 generated rule")

Expand Down
2 changes: 2 additions & 0 deletions java/gazelle/javaconfig/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ go_library(
importpath = "github.com/bazel-contrib/rules_jvm/java/gazelle/javaconfig",
visibility = ["//visibility:public"],
deps = [
"//java/gazelle/private/sorted_set",
"//java/gazelle/private/types",
"@com_github_bazelbuild_buildtools//build",
],
)
Expand Down
49 changes: 37 additions & 12 deletions java/gazelle/javaconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"path/filepath"
"strings"

"github.com/bazel-contrib/rules_jvm/java/gazelle/private/sorted_set"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
bzl "github.com/bazelbuild/buildtools/build"
)

Expand Down Expand Up @@ -47,6 +49,10 @@ const (
// JavaMavenRepositoryName tells the code generator what the repository name that contains all maven dependencies is.
// Defaults to "maven"
JavaMavenRepositoryName = "java_maven_repository_name"

// JavaAnnotationProcessorPlugin tells the code generator about specific java_plugin targets needed to process
// specific annotations.
JavaAnnotationProcessorPlugin = "java_annotation_processor_plugin"
)

// Configs is an extension of map[string]*Config. It provides finding methods
Expand All @@ -60,6 +66,10 @@ func (c *Config) NewChild() *Config {
for key, value := range c.excludedArtifacts {
clonedExcludedArtifacts[key] = value
}
annotationProcessorFullQualifiedClassToPluginClass := make(map[string]*sorted_set.SortedSet[types.ClassName])
for key, value := range c.annotationProcessorFullQualifiedClassToPluginClass {
annotationProcessorFullQualifiedClassToPluginClass[key] = value.Clone()
}
return &Config{
parent: c,
extensionEnabled: c.extensionEnabled,
Expand All @@ -74,6 +84,7 @@ func (c *Config) NewChild() *Config {
annotationToWrapper: c.annotationToWrapper,
excludedArtifacts: clonedExcludedArtifacts,
mavenRepositoryName: c.mavenRepositoryName,
annotationProcessorFullQualifiedClassToPluginClass: annotationProcessorFullQualifiedClassToPluginClass,
}
}

Expand All @@ -91,18 +102,19 @@ func (c *Configs) ParentForPackage(pkg string) *Config {
type Config struct {
parent *Config

extensionEnabled bool
isModuleRoot bool
generateProto bool
mavenInstallFile string
moduleGranularity string
repoRoot string
testMode string
customTestFileSuffixes *[]string
excludedArtifacts map[string]struct{}
annotationToAttribute map[string]map[string]bzl.Expr
annotationToWrapper map[string]string
mavenRepositoryName string
extensionEnabled bool
isModuleRoot bool
generateProto bool
mavenInstallFile string
moduleGranularity string
repoRoot string
testMode string
customTestFileSuffixes *[]string
excludedArtifacts map[string]struct{}
annotationToAttribute map[string]map[string]bzl.Expr
annotationToWrapper map[string]string
mavenRepositoryName string
annotationProcessorFullQualifiedClassToPluginClass map[string]*sorted_set.SortedSet[types.ClassName]
}

type LoadInfo struct {
Expand All @@ -125,6 +137,7 @@ func New(repoRoot string) *Config {
annotationToAttribute: make(map[string]map[string]bzl.Expr),
annotationToWrapper: make(map[string]string),
mavenRepositoryName: "maven",
annotationProcessorFullQualifiedClassToPluginClass: make(map[string]*sorted_set.SortedSet[types.ClassName]),
}
}

Expand Down Expand Up @@ -269,6 +282,18 @@ func (c *Config) IsTestRule(ruleKind string) bool {
return false
}

func (c *Config) GetAnnotationProcessorPluginClasses(annotationClass types.ClassName) *sorted_set.SortedSet[types.ClassName] {
return c.annotationProcessorFullQualifiedClassToPluginClass[annotationClass.FullyQualifiedClassName()]
}

func (c *Config) AddAnnotationProcessorPlugin(annotationClass types.ClassName, processorClass types.ClassName) {
fullyQualifiedAnnotationClass := annotationClass.FullyQualifiedClassName()
if _, ok := c.annotationProcessorFullQualifiedClassToPluginClass[fullyQualifiedAnnotationClass]; !ok {
c.annotationProcessorFullQualifiedClassToPluginClass[fullyQualifiedAnnotationClass] = sorted_set.NewSortedSetFn[types.ClassName](nil, types.ClassNameLess)
}
c.annotationProcessorFullQualifiedClassToPluginClass[fullyQualifiedAnnotationClass].Add(processorClass)
}

func equalStringSlices(l, r []string) bool {
if len(l) != len(r) {
return false
Expand Down
11 changes: 11 additions & 0 deletions java/gazelle/private/java/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ type Package struct {
PerClassMetadata map[string]PerClassMetadata
}

func (p *Package) AllAnnotations() *sorted_set.SortedSet[types.ClassName] {
annotations := sorted_set.NewSortedSetFn(nil, types.ClassNameLess)
for _, pcm := range p.PerClassMetadata {
annotations.AddAll(pcm.AnnotationClassNames)
for _, method := range pcm.MethodAnnotationClassNames.Keys() {
annotations.AddAll(pcm.MethodAnnotationClassNames.Values(method))
}
}
return annotations
}

type PerClassMetadata struct {
AnnotationClassNames *sorted_set.SortedSet[types.ClassName]
MethodAnnotationClassNames *sorted_multiset.SortedMultiSet[string, types.ClassName]
Expand Down
1 change: 1 addition & 0 deletions java/gazelle/private/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type ResolveInput struct {
PackageNames *sorted_set.SortedSet[PackageName]
ImportedPackageNames *sorted_set.SortedSet[PackageName]
ExportedPackageNames *sorted_set.SortedSet[PackageName]
AnnotationProcessors *sorted_set.SortedSet[ClassName]
}

type ResolvableJavaPackage struct {
Expand Down
63 changes: 45 additions & 18 deletions java/gazelle/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"sort"
"strings"

"github.com/bazel-contrib/rules_jvm/java/gazelle/javaconfig"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/java"
Expand Down Expand Up @@ -97,19 +98,13 @@ func (jr *Resolver) Resolve(c *config.Config, ix *resolve.RuleIndex, rc *repo.Re

jr.populateAttr(c, packageConfig, r, "deps", resolveInput.ImportedPackageNames, ix, isTestRule, from, resolveInput.PackageNames)
jr.populateAttr(c, packageConfig, r, "exports", resolveInput.ExportedPackageNames, ix, isTestRule, from, resolveInput.PackageNames)

jr.populatePluginsAttr(c, ix, resolveInput, packageConfig, from, isTestRule, r)
}

func (jr *Resolver) populateAttr(c *config.Config, pc *javaconfig.Config, r *rule.Rule, attrName string, requiredPackageNames *sorted_set.SortedSet[types.PackageName], ix *resolve.RuleIndex, isTestRule bool, from label.Label, ownPackageNames *sorted_set.SortedSet[types.PackageName]) {
labels := sorted_set.NewSortedSetFn[label.Label]([]label.Label{}, labelLess)

for _, implicitDep := range r.AttrStrings(attrName) {
l, err := label.Parse(implicitDep)
if err != nil {
panic(fmt.Sprintf("error converting implicit %s %q to label: %v", attrName, implicitDep, err))
}
labels.Add(l)
}

for _, imp := range requiredPackageNames.SortedSlice() {
dep := jr.resolveSinglePackage(c, pc, imp, ix, from, isTestRule, ownPackageNames)
if dep == label.NoLabel {
Expand All @@ -119,18 +114,26 @@ func (jr *Resolver) populateAttr(c *config.Config, pc *javaconfig.Config, r *rul
labels.Add(simplifyLabel(c.RepoName, dep, from))
}

var exprs []build.Expr
if labels.Len() > 0 {
for _, l := range labels.SortedSlice() {
if l.Relative && l.Name == from.Name {
continue
}
exprs = append(exprs, &build.StringExpr{Value: l.String()})
setLabelAttrIncludingExistingValues(r, attrName, labels)
}

func (jr *Resolver) populatePluginsAttr(c *config.Config, ix *resolve.RuleIndex, resolveInput types.ResolveInput, packageConfig *javaconfig.Config, from label.Label, isTestRule bool, r *rule.Rule) {
pluginLabels := sorted_set.NewSortedSetFn[label.Label]([]label.Label{}, labelLess)
for _, annotationProcessor := range resolveInput.AnnotationProcessors.SortedSlice() {
dep := jr.resolveSinglePackage(c, packageConfig, annotationProcessor.PackageName(), ix, from, isTestRule, resolveInput.PackageNames)
if dep == label.NoLabel {
continue
}

// Use the naming scheme for plugins as per https://github.com/bazelbuild/rules_jvm_external/pull/1102
// In the case of overrides (i.e. # gazelle:resolve targets) we require that they follow the same name-mangling scheme for the java_plugin target as rules_jvm_external uses.
// Ideally this would be a call to `java_plugin_artifact(dep.String(), annotationProcessor.FullyQualifiedClassName())` but we don't have function calls working in attributes.
dep.Name += "__java_plugin__" + strings.NewReplacer(".", "_", "$", "_").Replace(annotationProcessor.FullyQualifiedClassName())

pluginLabels.Add(simplifyLabel(c.RepoName, dep, from))
}
if len(exprs) > 0 {
r.SetAttr(attrName, exprs)
}

setLabelAttrIncludingExistingValues(r, "plugins", pluginLabels)
}

func labelLess(l, r label.Label) bool {
Expand Down Expand Up @@ -159,6 +162,30 @@ func simplifyLabel(repoName string, l label.Label, from label.Label) label.Label
return l
}

// Note: This function may modify labels.
func setLabelAttrIncludingExistingValues(r *rule.Rule, attrName string, labels *sorted_set.SortedSet[label.Label]) {
for _, implicitDep := range r.AttrStrings(attrName) {
l, err := label.Parse(implicitDep)
if err != nil {
panic(fmt.Sprintf("error converting implicit %s %q to label: %v", attrName, implicitDep, err))
}
labels.Add(l)
}

var exprs []build.Expr
if labels.Len() > 0 {
for _, l := range labels.SortedSlice() {
if l.Relative && l.Name == r.Name() {
continue
}
exprs = append(exprs, &build.StringExpr{Value: l.String()})
}
}
if len(exprs) > 0 {
r.SetAttr(attrName, exprs)
}
}

func (jr *Resolver) resolveSinglePackage(c *config.Config, pc *javaconfig.Config, imp types.PackageName, ix *resolve.RuleIndex, from label.Label, isTestRule bool, ownPackageNames *sorted_set.SortedSet[types.PackageName]) (out label.Label) {
cacheKey := types.NewResolvableJavaPackage(imp, false, false)
importSpec := resolve.ImportSpec{Lang: languageName, Imp: cacheKey.String()}
Expand Down
1 change: 1 addition & 0 deletions java/gazelle/testdata/annotation_processor/BUILD.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# gazelle:java_annotation_processor_plugin com.google.auto.value.AutoValue com.google.auto.value.processor.AutoValueProcessor
Loading

0 comments on commit a8e904b

Please sign in to comment.