Skip to content

Commit

Permalink
impl(language): remove unnecessary dependency on Rust Codec
Browse files Browse the repository at this point in the history
Remove dependency on the Rust codec in places where it is not required
by converting those methods to functions with a `rust` prefix.

Also remove some unnecessary functions and arguments.
  • Loading branch information
julieqiu committed Jan 12, 2025
1 parent 137125d commit 5fbe9f1
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 121 deletions.
2 changes: 1 addition & 1 deletion generator/internal/language/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func GenerateClient(model *api.API, language, outdir string, options map[string]
return err
}
data = newRustTemplateData(model, codec)
provider = codec.templatesProvider()
provider = rustTemplatesProvider()
generatedFiles = codec.generatedFiles()
case "go":
var err error
Expand Down
132 changes: 60 additions & 72 deletions generator/internal/language/rust.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func scalarFieldType(f *api.Field) string {
return out
}

func (c *rustCodec) fieldFormatter(typez api.Typez) string {
func rustFieldFormatter(typez api.Typez) string {
switch typez {
case api.INT64_TYPE,
api.UINT64_TYPE,
Expand All @@ -339,7 +339,7 @@ func (c *rustCodec) fieldFormatter(typez api.Typez) string {
}
}

func (c *rustCodec) fieldSkipAttributes(f *api.Field) []string {
func rustFieldSkipAttributes(f *api.Field) []string {
switch f.Typez {
case api.STRING_TYPE:
return []string{`#[serde(skip_serializing_if = "String::is_empty")]`}
Expand All @@ -350,14 +350,14 @@ func (c *rustCodec) fieldSkipAttributes(f *api.Field) []string {
}
}

func (c *rustCodec) fieldBaseAttributes(f *api.Field) []string {
if c.toCamel(c.toSnake(f.Name)) != f.JSONName {
func rustFieldBaseAttributes(f *api.Field) []string {
if rustToCamel(rustToSnake(f.Name)) != f.JSONName {
return []string{fmt.Sprintf(`#[serde(rename = "%s")]`, f.JSONName)}
}
return []string{}
}

func (c *rustCodec) wrapperFieldAttributes(f *api.Field, attributes []string) []string {
func rustWrapperFieldAttributes(f *api.Field, attributes []string) []string {
// Message fields could be `Vec<..>`, and are always optional:
if f.Optional {
attributes = append(attributes, `#[serde(skip_serializing_if = "Option::is_none")]`)
Expand All @@ -368,11 +368,11 @@ func (c *rustCodec) wrapperFieldAttributes(f *api.Field, attributes []string) []
var formatter string
switch f.TypezID {
case ".google.protobuf.BytesValue":
formatter = c.fieldFormatter(api.BYTES_TYPE)
formatter = rustFieldFormatter(api.BYTES_TYPE)
case ".google.protobuf.UInt64Value":
formatter = c.fieldFormatter(api.UINT64_TYPE)
formatter = rustFieldFormatter(api.UINT64_TYPE)
case ".google.protobuf.Int64Value":
formatter = c.fieldFormatter(api.INT64_TYPE)
formatter = rustFieldFormatter(api.INT64_TYPE)
default:
return attributes
}
Expand All @@ -383,11 +383,11 @@ func (c *rustCodec) wrapperFieldAttributes(f *api.Field, attributes []string) []
fmt.Sprintf(`#[serde_as(as = "Option<%s>")]`, formatter))
}

func (c *rustCodec) fieldAttributes(f *api.Field, state *api.APIState) []string {
func rustFieldAttributes(f *api.Field, state *api.APIState) []string {
if f.Synthetic {
return []string{`#[serde(skip)]`}
}
attributes := c.fieldBaseAttributes(f)
attributes := rustFieldBaseAttributes(f)
switch f.Typez {
case api.DOUBLE_TYPE,
api.FLOAT_TYPE,
Expand All @@ -406,15 +406,15 @@ func (c *rustCodec) fieldAttributes(f *api.Field, state *api.APIState) []string
if f.Repeated {
return append(attributes, `#[serde(skip_serializing_if = "Vec::is_empty")]`)
}
return append(attributes, c.fieldSkipAttributes(f)...)
return append(attributes, rustFieldSkipAttributes(f)...)

case api.INT64_TYPE,
api.UINT64_TYPE,
api.FIXED64_TYPE,
api.SFIXED64_TYPE,
api.SINT64_TYPE,
api.BYTES_TYPE:
formatter := c.fieldFormatter(f.Typez)
formatter := rustFieldFormatter(f.Typez)
if f.Optional {
attributes = append(attributes, `#[serde(skip_serializing_if = "Option::is_none")]`)
return append(attributes, fmt.Sprintf(`#[serde_as(as = "Option<%s>")]`, formatter))
Expand All @@ -423,7 +423,7 @@ func (c *rustCodec) fieldAttributes(f *api.Field, state *api.APIState) []string
attributes = append(attributes, `#[serde(skip_serializing_if = "Vec::is_empty")]`)
return append(attributes, fmt.Sprintf(`#[serde_as(as = "Vec<%s>")]`, formatter))
}
attributes = append(attributes, c.fieldSkipAttributes(f)...)
attributes = append(attributes, rustFieldSkipAttributes(f)...)
return append(attributes, fmt.Sprintf(`#[serde_as(as = "%s")]`, formatter))

case api.MESSAGE_TYPE:
Expand All @@ -444,14 +444,14 @@ func (c *rustCodec) fieldAttributes(f *api.Field, state *api.APIState) []string
slog.Error("missing key or value in map field")
return attributes
}
keyFormat := c.fieldFormatter(key.Typez)
valFormat := c.fieldFormatter(value.Typez)
keyFormat := rustFieldFormatter(key.Typez)
valFormat := rustFieldFormatter(value.Typez)
if keyFormat == "_" && valFormat == "_" {
return attributes
}
return append(attributes, fmt.Sprintf(`#[serde_as(as = "std::collections::HashMap<%s, %s>")]`, keyFormat, valFormat))
}
return c.wrapperFieldAttributes(f, attributes)
return rustWrapperFieldAttributes(f, attributes)

default:
slog.Error("unexpected field type", "field", *f)
Expand Down Expand Up @@ -485,7 +485,7 @@ func (c *rustCodec) baseFieldType(f *api.Field, state *api.APIState) string {
val := c.fieldType(m.Fields[1], state, false)
return "std::collections::HashMap<" + key + "," + val + ">"
}
return c.fqMessageName(m, state)
return c.fqMessageName(m)
} else if f.Typez == api.ENUM_TYPE {
e, ok := state.EnumByID[f.TypezID]
if !ok {
Expand All @@ -501,19 +501,19 @@ func (c *rustCodec) baseFieldType(f *api.Field, state *api.APIState) string {

}

func (c *rustCodec) asQueryParameter(f *api.Field) string {
func rustAsQueryParameter(f *api.Field) string {
if f.Typez == api.MESSAGE_TYPE {
// Query parameters in nested messages are first converted to a
// `serde_json::Value`` and then recursively merged into the request
// query. The conversion to `serde_json::Value` is expensive, but very
// few requests use nested objects as query parameters. Furthermore,
// the conversion is skipped if the object field is `None`.`
return fmt.Sprintf("&serde_json::to_value(&req.%s).map_err(Error::serde)?", c.toSnake(f.Name))
return fmt.Sprintf("&serde_json::to_value(&req.%s).map_err(Error::serde)?", rustToSnake(f.Name))
}
return fmt.Sprintf("&req.%s", c.toSnake(f.Name))
return fmt.Sprintf("&req.%s", rustToSnake(f.Name))
}

func (c *rustCodec) templatesProvider() templateProvider {
func rustTemplatesProvider() templateProvider {
return func(name string) (string, error) {
contents, err := rustTemplates.ReadFile(name)
if err != nil {
Expand Down Expand Up @@ -545,7 +545,7 @@ func (c *rustCodec) methodInOutTypeName(id string, state *api.APIState) string {
slog.Error("unable to lookup type", "id", id)
return ""
}
return c.fqMessageName(m, state)
return c.fqMessageName(m)
}

func (c *rustCodec) rustPackage(packageName string) string {
Expand Down Expand Up @@ -576,59 +576,59 @@ func (c *rustCodec) messageAttributes(*api.Message, *api.APIState) []string {
}
}

func (c *rustCodec) messageName(m *api.Message) string {
return c.toPascal(m.Name)
func rustMessageName(m *api.Message) string {
return rustToPascal(m.Name)
}

func (c *rustCodec) messageScopeName(m *api.Message, childPackageName string) string {
if m == nil {
return c.rustPackage(childPackageName)
}
if m.Parent == nil {
return c.rustPackage(m.Package) + "::" + c.toSnake(m.Name)
return c.rustPackage(m.Package) + "::" + rustToSnake(m.Name)
}
return c.messageScopeName(m.Parent, m.Package) + "::" + c.toSnake(m.Name)
return c.messageScopeName(m.Parent, m.Package) + "::" + rustToSnake(m.Name)
}

func (c *rustCodec) enumScopeName(e *api.Enum) string {
return c.messageScopeName(e.Parent, e.Package)
}

func (c *rustCodec) fqMessageName(m *api.Message, _ *api.APIState) string {
return c.messageScopeName(m.Parent, m.Package) + "::" + c.toPascal(m.Name)
func (c *rustCodec) fqMessageName(m *api.Message) string {
return c.messageScopeName(m.Parent, m.Package) + "::" + rustToPascal(m.Name)
}

func (c *rustCodec) enumName(e *api.Enum) string {
return c.toPascal(e.Name)
func rustEnumName(e *api.Enum) string {
return rustToPascal(e.Name)
}

func (c *rustCodec) fqEnumName(e *api.Enum) string {
return c.messageScopeName(e.Parent, e.Package) + "::" + c.toPascal(e.Name)
return c.messageScopeName(e.Parent, e.Package) + "::" + rustToPascal(e.Name)
}

func (c *rustCodec) enumValueName(e *api.EnumValue, _ *api.APIState) string {
func rustEnumValueName(e *api.EnumValue) string {
// The Protobuf naming convention is to use SCREAMING_SNAKE_CASE, we do not
// need to change anything for Rust
return rustEscapeKeyword(e.Name)
}

func (c *rustCodec) fqEnumValueName(v *api.EnumValue, state *api.APIState) string {
return fmt.Sprintf("%s::%s::%s", c.enumScopeName(v.Parent), c.toSnake(v.Parent.Name), c.enumValueName(v, state))
func (c *rustCodec) fqEnumValueName(v *api.EnumValue) string {
return fmt.Sprintf("%s::%s::%s", c.enumScopeName(v.Parent), rustToSnake(v.Parent.Name), rustEnumValueName(v))
}

func (c *rustCodec) oneOfType(o *api.OneOf, _ *api.APIState) string {
return c.messageScopeName(o.Parent, "") + "::" + c.toPascal(o.Name)
func (c *rustCodec) oneOfType(o *api.OneOf) string {
return c.messageScopeName(o.Parent, "") + "::" + rustToPascal(o.Name)
}

func (c *rustCodec) bodyAccessor(m *api.Method) string {
func rustBodyAccessor(m *api.Method) string {
if m.PathInfo.BodyFieldPath == "*" {
// no accessor needed, use the whole request
return ""
}
return "." + c.toSnake(m.PathInfo.BodyFieldPath)
return "." + rustToSnake(m.PathInfo.BodyFieldPath)
}

func (c *rustCodec) httpPathFmt(m *api.PathInfo) string {
func rustHTTPPathFmt(m *api.PathInfo) string {
fmt := ""
for _, segment := range m.PathTemplate {
if segment.Literal != nil {
Expand Down Expand Up @@ -670,26 +670,26 @@ func (c *rustCodec) httpPathFmt(m *api.PathInfo) string {
// ```
//
// and so on.
func (c *rustCodec) unwrapFieldPath(components []string, requestAccess string) (string, string) {
func rustUnwrapFieldPath(components []string, requestAccess string) (string, string) {
if len(components) == 1 {
return requestAccess + "." + c.toSnake(components[0]), components[0]
return requestAccess + "." + rustToSnake(components[0]), components[0]
}
unwrap, name := c.unwrapFieldPath(components[0:len(components)-1], "&req")
unwrap, name := rustUnwrapFieldPath(components[0:len(components)-1], "&req")
last := components[len(components)-1]
return fmt.Sprintf("gax::path_parameter::PathParameter::required(%s, \"%s\").map_err(Error::other)?.%s", unwrap, name, last), ""
}

func (c *rustCodec) derefFieldPath(fieldPath string) string {
func rustDerefFieldPath(fieldPath string) string {
components := strings.Split(fieldPath, ".")
unwrap, _ := c.unwrapFieldPath(components, "req")
unwrap, _ := rustUnwrapFieldPath(components, "req")
return unwrap
}

func (c *rustCodec) httpPathArgs(h *api.PathInfo) []string {
func rustHTTPPathArgs(h *api.PathInfo) []string {
var args []string
for _, arg := range h.PathTemplate {
if arg.FieldPath != nil {
args = append(args, c.derefFieldPath(*arg.FieldPath))
args = append(args, rustDerefFieldPath(*arg.FieldPath))
}
}
return args
Expand All @@ -701,11 +701,11 @@ func (c *rustCodec) httpPathArgs(h *api.PathInfo) []string {
// This type of conversion can easily introduce keywords. Consider
//
// `toSnake("True") -> "true"`
func (c *rustCodec) toSnake(symbol string) string {
return rustEscapeKeyword(c.toSnakeNoMangling(symbol))
func rustToSnake(symbol string) string {
return rustEscapeKeyword(rustToSnakeNoMangling(symbol))
}

func (*rustCodec) toSnakeNoMangling(symbol string) string {
func rustToSnakeNoMangling(symbol string) string {
if strings.ToLower(symbol) == symbol {
return symbol
}
Expand All @@ -719,7 +719,7 @@ func (*rustCodec) toSnakeNoMangling(symbol string) string {
// This type of conversion rarely introduces keywords. The one example is
//
// `toPascal("self") -> "Self"`
func (*rustCodec) toPascal(symbol string) string {
func rustToPascal(symbol string) string {
if symbol == "" {
return ""
}
Expand All @@ -730,7 +730,7 @@ func (*rustCodec) toPascal(symbol string) string {
return rustEscapeKeyword(strcase.ToCamel(symbol))
}

func (*rustCodec) toCamel(symbol string) string {
func rustToCamel(symbol string) string {
return rustEscapeKeyword(strcase.ToLowerCamel(symbol))
}

Expand Down Expand Up @@ -863,7 +863,7 @@ func (c *rustCodec) rustdocLink(link string, state *api.APIState) string {
id := fmt.Sprintf(".%s", link)
m, ok := state.MessageByID[id]
if ok {
return c.fqMessageName(m, state)
return c.fqMessageName(m)
}
e, ok := state.EnumByID[id]
if ok {
Expand Down Expand Up @@ -902,20 +902,20 @@ func (c *rustCodec) tryFieldRustdocLink(id string, state *api.APIState) string {
for _, f := range m.Fields {
if f.Name == fieldName {
if !f.IsOneOf {
return fmt.Sprintf("%s::%s", c.fqMessageName(m, state), c.toSnake(f.Name))
return fmt.Sprintf("%s::%s", c.fqMessageName(m), rustToSnake(f.Name))
} else {
return c.tryOneOfRustdocLink(f, m, state)
return c.tryOneOfRustdocLink(f, m)
}
}
}
return ""
}

func (c *rustCodec) tryOneOfRustdocLink(field *api.Field, message *api.Message, state *api.APIState) string {
func (c *rustCodec) tryOneOfRustdocLink(field *api.Field, message *api.Message) string {
for _, o := range message.OneOfs {
for _, f := range o.Fields {
if f.ID == field.ID {
return fmt.Sprintf("%s::%s", c.fqMessageName(message, state), c.toSnake(o.Name))
return fmt.Sprintf("%s::%s", c.fqMessageName(message), rustToSnake(o.Name))
}
}
}
Expand All @@ -935,7 +935,7 @@ func (c *rustCodec) tryEnumValueRustdocLink(id string, state *api.APIState) stri
}
for _, v := range e.Values {
if v.Name == valueName {
return c.fqEnumValueName(v, state)
return c.fqEnumValueName(v)
}
}
return ""
Expand All @@ -944,7 +944,7 @@ func (c *rustCodec) tryEnumValueRustdocLink(id string, state *api.APIState) stri
func (c *rustCodec) methodRustdocLink(m *api.Method, state *api.APIState) string {
// Sometimes we remove methods from a service. In that case we cannot
// reference the method.
if !c.generateMethod(m) {
if !rustGenerateMethod(m) {
return ""
}
idx := strings.LastIndex(m.ID, ".")
Expand All @@ -956,7 +956,7 @@ func (c *rustCodec) methodRustdocLink(m *api.Method, state *api.APIState) string
if !ok {
return ""
}
return fmt.Sprintf("%s::%s", c.serviceRustdocLink(s, state), c.toSnake(m.Name))
return fmt.Sprintf("%s::%s", c.serviceRustdocLink(s, state), rustToSnake(m.Name))
}

func (c *rustCodec) serviceRustdocLink(s *api.Service, _ *api.APIState) string {
Expand Down Expand Up @@ -1018,14 +1018,6 @@ func (c *rustCodec) requiredPackages() []string {
return lines
}

func (c *rustCodec) copyrightYear() string {
return c.generationYear
}

func (c *rustCodec) packageVersion() string {
return c.version
}

func (c *rustCodec) packageName(api *api.API) string {
if len(c.packageNameOverride) > 0 {
return c.packageNameOverride
Expand Down Expand Up @@ -1107,11 +1099,7 @@ func (c *rustCodec) addStreamingFeature(data *RustTemplateData, api *api.API) {
data.HasFeatures = true
}

func (c *rustCodec) notForPublication() bool {
return c.doNotPublish
}

func (c *rustCodec) generateMethod(m *api.Method) bool {
func rustGenerateMethod(m *api.Method) bool {
// Ignore methods without HTTP annotations, we cannot generate working
// RPCs for them.
// TODO(#499) - switch to explicitly excluding such functions. Easier to
Expand Down
Loading

0 comments on commit 5fbe9f1

Please sign in to comment.