From fd162cfaec7b4028ddbde888e9b1ea4eed294928 Mon Sep 17 00:00:00 2001 From: Troy Sornson Date: Wed, 23 Oct 2024 21:55:22 -0600 Subject: [PATCH] Add apply-types command to compiler Implement typer writing (inefficiently) Implement formatter, support multiple filename inputs Partial support for types, still need modules Update to latest source typer changes Add code comments Comment out broken spec for now Uncomment and fix final spec, add semantic / progress_tracker flags Remove focus: true (oops) Good ol' print debugging for windows CI failure Reimplement def visitor def locator matching for windows Back to print debugging Fix and support windows drive letters for root folders --- spec/compiler/apply_types_spec.cr | 452 ++++++++++++++++++++++ src/compiler/crystal/command.cr | 4 + src/compiler/crystal/command/typer.cr | 92 +++++ src/compiler/crystal/tools/typer.cr | 537 ++++++++++++++++++++++++++ 4 files changed, 1085 insertions(+) create mode 100644 spec/compiler/apply_types_spec.cr create mode 100644 src/compiler/crystal/command/typer.cr create mode 100644 src/compiler/crystal/tools/typer.cr diff --git a/spec/compiler/apply_types_spec.cr b/spec/compiler/apply_types_spec.cr new file mode 100644 index 000000000000..2cc0b407bb97 --- /dev/null +++ b/spec/compiler/apply_types_spec.cr @@ -0,0 +1,452 @@ +require "./spec_helper" + +def run_source_typer_spec(input, expected_output, + splats : Bool = true, + line_number : Int32 = 1, + named_splats : Bool = true, + blocks : Bool = true, + prelude : String = "") + entrypoint_file = File.expand_path("entrypoint.cr") + locator = line_number > 0 ? "#{entrypoint_file}:#{line_number}" : entrypoint_file + typer = Crystal::SourceTyper.new(entrypoint_file, [locator], blocks, splats, named_splats, prelude) + + typer.semantic(entrypoint_file, input) + + typer.files.to_a.should eq [entrypoint_file] + result = typer.type_source(entrypoint_file, input) + result.should_not be_nil + not_nil_result = result.not_nil!("Why is this failing???") + not_nil_result.strip.should eq expected_output +end + +describe Crystal::SourceTyper do + it "types method return types" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello + "world!" + end + + hello + INPUT + def hello : String + "world!" + end + + hello + OUTPUT + end + + it "types positional arguments" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(arg) + arg + end + hello("world") + INPUT + def hello(arg : String) : String + arg + end + + hello("world") + OUTPUT + end + + it "types positional args with unions" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(arg) + arg + end + hello("world") + hello(3) + INPUT + def hello(arg : String | Int32) : String | Int32 + arg + end + + hello("world") + hello(3) + OUTPUT + end + + it "types splats, single type" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(*arg) + nil + end + hello("world") + INPUT + def hello(*arg : String) : Nil + nil + end + + hello("world") + OUTPUT + end + + it "types splats, multiple calls with single type" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(*arg) + nil + end + hello("world") + hello(3) + INPUT + def hello(*arg : String | Int32) : Nil + nil + end + + hello("world") + hello(3) + OUTPUT + end + + it "types splats, one call with multiple types" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(*arg) + nil + end + hello("world", 3) + INPUT + def hello(*arg : String | Int32) : Nil + nil + end + + hello("world", 3) + OUTPUT + end + + it "types arguments but not splats" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, splats: false) + def hello(the_arg, *arg) + nil + end + hello(2, "world", 3) + INPUT + def hello(the_arg : Int32, *arg) : Nil + nil + end + + hello(2, "world", 3) + OUTPUT + end + + it "doesn't type splats with empty call" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(*arg) + nil + end + hello("world") + hello + INPUT + def hello(*arg) : Nil + nil + end + + hello("world") + hello + OUTPUT + end + + it "types double splats, single type" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(**args) + nil + end + hello(hello: "world") + INPUT + def hello(**args : String) : Nil + nil + end + + hello(hello: "world") + OUTPUT + end + + it "types double splats, multiple types" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(**args) + nil + end + hello(hello: "world", world: 3) + INPUT + def hello(**args : String | Int32) : Nil + nil + end + + hello(hello: "world", world: 3) + OUTPUT + end + + it "types arguments but not double splats" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, named_splats: false) + def hello(the_arg, **args) + nil + end + hello(2, hello: "world", world: 3) + INPUT + def hello(the_arg : Int32, **args) : Nil + nil + end + + hello(2, hello: "world", world: 3) + OUTPUT + end + + it "types splats but not double splats" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, named_splats: false) + def hello(*arg, **args) + nil + end + hello(2, hello: "world", world: 3) + INPUT + def hello(*arg : Int32, **args) : Nil + nil + end + + hello(2, hello: "world", world: 3) + OUTPUT + end + + it "types double plats but not splats" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, splats: false) + def hello(*arg, **args) + nil + end + hello(2, hello: "world", world: 3) + INPUT + def hello(*arg, **args : String | Int32) : Nil + nil + end + + hello(2, hello: "world", world: 3) + OUTPUT + end + + it "doesn't type double splat with empty call" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(**args) + nil + end + hello(hello: "world", world: 3) + hello + INPUT + def hello(**args) : Nil + nil + end + + hello(hello: "world", world: 3) + hello + OUTPUT + end + + it "types blocks" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(&block) + block + end + hello {} + INPUT + def hello(&block : Proc(Nil)) : Proc(Nil) + block + end + + hello { } + OUTPUT + end + + it "types class instance methods" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, line_number: 2) + class Test + def hello(arg) + arg + end + end + Test.new.hello(3) + INPUT + class Test + def hello(arg : Int32) : Int32 + arg + end + end + + Test.new.hello(3) + OUTPUT + end + + it "types class methods" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, line_number: 2) + class Test + def self.hello(arg) + arg + end + end + Test.hello(3) + INPUT + class Test + def self.hello(arg : Int32) : Int32 + arg + end + end + + Test.hello(3) + OUTPUT + end + + it "types method included from module" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, line_number: 2) + module IncludeMe + def hello(arg) + arg + end + end + + class Test + include IncludeMe + end + Test.new.hello(3) + INPUT + module IncludeMe + def hello(arg : Int32) : Int32 + arg + end + end + + class Test + include IncludeMe + end + + Test.new.hello(3) + OUTPUT + end + + it "doesn't remove newline when inserting return types" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello + # world + "world" + end + hello + INPUT + def hello : String + # world + "world" + end + + hello + OUTPUT + end + + it "turns unions with nil to have a '?' suffix" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT) + def hello(arg) + nil + end + hello(nil) + hello("world") + INPUT + def hello(arg : String?) : Nil + nil + end + + hello(nil) + hello("world") + OUTPUT + end + + it "runs prelude and types everything" do + run_source_typer_spec(<<-INPUT, <<-OUTPUT, line_number: -1, prelude: "prelude") + # This file tries to capture each type of definition format + def hello + "world" + end + + def hello1(arg1) + arg1 + end + + def hello2(arg1, *, arg2) + arg1 + arg2 + end + + def hello3(&block) + block.call + end + + def hello4(*args) + args[0]? + end + + def hello5(**args) + nil + end + + class Test + def hello + "world" + end + + def self.hello + "world" + end + end + + hello + hello1("world") + hello2(1, arg2: 2) + hello3 do + "hello" + end + hello4(3, "ok") + hello5(test: "test", other: 3) + Test.hello + Test.new.hello + + INPUT + # This file tries to capture each type of definition format + def hello : String + "world" + end + + def hello1(arg1 : String) : String + arg1 + end + + def hello2(arg1 : Int32, *, arg2 : Int32) : Int32 + arg1 + arg2 + end + + def hello3(&block : Proc(Nil)) : Nil + block.call + end + + def hello4(*args : Int32 | String) : Int32 + args[0]? + end + + def hello5(**args : String | Int32) : Nil + nil + end + + class Test + def hello : String + "world" + end + + def self.hello : String + "world" + end + end + + hello + hello1("world") + hello2(1, arg2: 2) + hello3 do + "hello" + end + hello4(3, "ok") + hello5(test: "test", other: 3) + Test.hello + Test.new.hello + OUTPUT + end +end diff --git a/src/compiler/crystal/command.cr b/src/compiler/crystal/command.cr index 571c965352e0..c7b8021b104b 100644 --- a/src/compiler/crystal/command.cr +++ b/src/compiler/crystal/command.cr @@ -46,6 +46,7 @@ class Crystal::Command format format project, directories and/or files hierarchy show type hierarchy implementations show implementations for given call in location + apply-types add types to all untyped defs and def arguments types show type of main variables unreachable show methods that are never called --help, -h show this help @@ -196,6 +197,9 @@ class Crystal::Command when "implementations".starts_with?(tool) options.shift implementations + when "apply-types".starts_with?(tool) + options.shift + typer when "types".starts_with?(tool) options.shift types diff --git a/src/compiler/crystal/command/typer.cr b/src/compiler/crystal/command/typer.cr new file mode 100644 index 000000000000..f467b9a1d46c --- /dev/null +++ b/src/compiler/crystal/command/typer.cr @@ -0,0 +1,92 @@ +# Implementation of the `crystal tool format` command +# +# This is just the command-line part. The formatter +# logic is in `crystal/tools/formatter.cr`. + +class Crystal::Command + private def typer + prelude = "prelude" + type_blocks = false + type_splats = false + type_double_splats = false + stats = false + progress = false + error_trace = false + + OptionParser.parse(options) do |opts| + opts.banner = <<-USAGE + Usage: typer [options] entrypoint [def_descriptor [def_descriptor [...]]] + + A def_descriptor comes in 4 formats: + + * A directory name ('src/') + * A file ('src/my_project.cr') + * A line number in a file ('src/my_project.cr:3') + * The location of the def method to be typed, specifically ('src/my_project.cr:3:3') + + If a `def` definition matches a provided def_descriptor, then it will be typed if type restrictions are missing. + If no dev_descriptors are provided, then 'src' is tried, or all files under current directory (and sub directories, recursive) + are typed if no 'src' directory exists. + + Options: + USAGE + + opts.on("-h", "--help", "Show this message") do + puts opts + exit + end + + opts.on("--prelude [PRELUDE]", "Use given file as prelude. Use empty string to skip prelude entirely.") do |new_prelude| + prelude = new_prelude + end + + opts.on("--include-blocks", "Enable adding types to named block arguments (these usually get typed with Proc(Nil) and isn't helpful)") do + type_blocks = true + end + + opts.on("--include-splats", "Enable adding types to splats") do + type_splats = true + end + + opts.on("--include-double-splats", "Enable adding types to double splats") do + type_double_splats = true + end + + opts.on("--stats", "Enable statistics output") do + stats = true + end + + opts.on("--progress", "Enable progress output") do + progress = true + end + + opts.on("--error-trace", "Show full error trace") do + error_trace = true + end + end + + entrypoint = options.shift + def_locators = options + + results = SourceTyper.new( + entrypoint, + def_locators, + type_blocks, + type_splats, + type_double_splats, + prelude, + stats, + progress, + error_trace + ).run + + if results.empty? + puts "No type restrictions added" + else + results.each do |filename, file_contents| + # pp! filename, file_contents + File.write(filename, file_contents) + end + end + end +end diff --git a/src/compiler/crystal/tools/typer.cr b/src/compiler/crystal/tools/typer.cr new file mode 100644 index 000000000000..ebaa9d70d034 --- /dev/null +++ b/src/compiler/crystal/tools/typer.cr @@ -0,0 +1,537 @@ +module Crystal + class SourceTyper + # Represents a fully typed definition signature + record Signature, + name : String, + return_type : Crystal::ASTNode, + location : Crystal::Location, + args = {} of String => Crystal::ASTNode + + getter program, files + + def initialize(@entrypoint : String, + @def_locators : Array(String), + @type_blocks : Bool, + @type_splats : Bool, + @type_double_splats : Bool, + @prelude : String = "prelude", + stats : Bool = false, + progress : Bool = false, + error_trace : Bool = false) + @entrypoint = File.expand_path(@entrypoint) unless @entrypoint.starts_with?("/") + @program = Crystal::Program.new + @files = Set(String).new + @warnings = [] of String + + @program.progress_tracker.stats = stats + @program.progress_tracker.progress = progress + @program.show_error_trace = error_trace + end + + # Run the entire typing flow, from semantic to file reformatting + def run : Hash(String, String) + semantic(@entrypoint, File.read(@entrypoint)) + + rets = {} of String => String + + @warnings.each do |warning| + puts "WARNING: #{warning}" + end + + @files.each do |file| + next unless File.file?(file) + source = File.read(file) + if typed_source = type_source(file, source) + rets[file] = typed_source + end + end + + rets + end + + # Take the entrypoint file (and its textual content) and run semantic on it. + # Semantic results are used to generate signatures for all defs that match + # at least one def_locator. + def semantic(entrypoint, entrypoint_content) : Nil + parser = program.new_parser(entrypoint_content) + parser.filename = entrypoint + parser.wants_doc = false + original_node = parser.parse + + nodes = Crystal::Expressions.from([original_node]) + + if !@prelude.empty? + # Prepend the prelude to the parsed program + location = Crystal::Location.new(entrypoint, 1, 1) + nodes = Crystal::Expressions.new([Crystal::Require.new(@prelude).at(location), nodes] of Crystal::ASTNode) + end + + program.normalize(nodes) + + # And now infer types of everything + semantic_node = program.semantic nodes, cleanup: true + + # Use the DefVisitor to locate and match any 'def's that match a def_locator + def_visitor = DefVisitor.new(@def_locators, entrypoint) + semantic_node.accept(def_visitor) + + # Hash up the location => (parsed) definition. + # At this point the types have been infeered (from semantic above) and stored in various + # def_instances in the `program` arg and its types. + accepted_defs = def_visitor.all_defs.map do |the_def| + { + the_def.location.to_s, + the_def, + } + end.to_h + init_signatures(accepted_defs) + + @files = def_visitor.files + end + + # Given a (presumably) crystal file and its content, re-format it with the crystal-formatter-that-types-things (SourceTyperFormatter). + # Returns nil if no type restrictions were added anywhere. + def type_source(filename, source) : String? + formatter = SourceTyperFormatter.new(source, signatures) + + parser = program.new_parser(source) + parser.filename = filename + parser.wants_doc = false + original_node = parser.parse + + formatter.skip_space_or_newline + original_node.accept formatter + + formatter.added_types? ? formatter.finish : nil + end + + # If a def is already fully typed, we don't need to check / write it + private def fully_typed?(d : Def) : Bool + ret = true + ret &= d.args.all?(&.restriction) + ret &= !!d.return_type + ret + end + + @_signatures : Hash(String, Signature)? + + # Signatures represents a mapping of location => Signature for def at that location + def signatures : Hash(String, Signature) + @_signatures || raise "Signatures not properly initialized!" + end + + # Given `accepted_defs` (location => (parsed) defs that match a def_locator), generated a new hash of + # location => (typed, multiple) def_instances that match a location. + # + # A given parsed def can have multiple def_instances, depending on how the method is called throughout + # the program, and the types of those calls. + private def accepted_def_instances(accepted_defs : Hash(String, Crystal::Def)) : Hash(String, Array(Crystal::Def)) + ret = Hash(String, Array(Crystal::Def)).new do |h, k| + h[k] = [] of Crystal::Def + end + + # First, check global definitions + program.def_instances.each do |_, def_instance| + next unless accepted_defs.keys.includes?(def_instance.location.to_s) + + ret[def_instance.location.to_s] << def_instance + end + + # Breadth first search time! This list will be a continuously populated queue of all of the types we need + # to scan, with newly discovered types added to the end of the queue from "parent" (namespace) types. + types = [] of Crystal::Type + + program.types.each { |_, t| types << t } + + while type = types.shift? + type.types?.try &.each { |_, t| types << t } + + # Check for class instance 'def's + if type.responds_to?(:def_instances) + type.def_instances.each do |_, def_instance| + next unless accepted_defs.keys.includes?(def_instance.location.to_s) + + ret[def_instance.location.to_s] << def_instance + end + end + + # Check for class 'self.def's + metaclass = type.metaclass + if metaclass.responds_to?(:def_instances) + metaclass.def_instances.each do |_, def_instance| + next unless accepted_defs.keys.includes?(def_instance.location.to_s) + + ret[def_instance.location.to_s] << def_instance + end + end + end + + ret + end + + # Given an 'arg', return its type that's good for printing (VirtualTypes suffix themselves with a '+') + private def resolve_type(arg) + t = arg.type + t.is_a?(Crystal::VirtualType) ? t.base_type : t + end + + # Strip out any NoReturns, or Procs that point to them (maybe all generics? Start with procs) + private def filter_no_return(types) + compacted_types = types.to_a.reject! do |type| + type.no_return? || (type.is_a?(Crystal::ProcInstanceType) && type.as(Crystal::ProcInstanceType).return_type.no_return?) + end + + compacted_types << program.nil if compacted_types.empty? + compacted_types + end + + # Generates a map of (parsed) Def#location => Signature for that Def + private def init_signatures(accepted_defs : Hash(String, Crystal::Def)) : Hash(String, Signature) + # This is hard to read, but transforms the def_instances array into a hash of def.location -> its full Signature + @_signatures ||= accepted_def_instances(accepted_defs).compact_map do |location, def_instances| + # Finally, combine all def_instances for a single def_obj_id into a single signature + + parsed = accepted_defs[location] + + all_typed_args = Hash(String, Set(Crystal::Type)).new { |h, k| h[k] = Set(Crystal::Type).new } + + # splats only exist in the parsed defs, while the def_instances have all had their splats "exploded". + # For typing splats, use the parsed defs for splat names and scan def_intances for various arg names that look... splatty. + safe_splat_index = parsed.splat_index || Int32::MAX + splat_arg_name = parsed.args[safe_splat_index]?.try &.name.try { |name| name.empty? ? nil : name } + named_arg_name = parsed.double_splat.try &.name + + encountered_non_splat_arg_def_instance = false + encountered_non_double_splat_arg_def_instance = false + + def_instances.each do |def_instance| + encountered_splat_arg = false + encountered_double_splat_arg = false + def_instance.args.each do |arg| + if arg.name == arg.external_name && !arg.name.starts_with?("__temp_") + # Regular arg + all_typed_args[arg.external_name] << resolve_type(arg) + elsif @type_splats && (splat_arg = splat_arg_name) && arg.name == arg.external_name && arg.name.starts_with?("__temp_") + # Splat arg, where the compiler generated a uniq name for it + encountered_splat_arg = true + all_typed_args[splat_arg] << resolve_type(arg) + elsif @type_double_splats && (named_arg = named_arg_name) && arg.name != arg.external_name && arg.name.starts_with?("__temp_") + # Named splat arg, where an "external" name was retained, but compiler generated uniq name for it + encountered_double_splat_arg = true + all_typed_args[named_arg] << resolve_type(arg) + elsif (!@type_splats || !@type_double_splats) && arg.name.starts_with?("__temp_") + # Ignore, it didn't fall into one of the above conditions (i.e. typing a particular splat wasn't specified) + else + raise "Unknown handling of arg #{arg} in #{def_instance}\n#{parsed}" + end + end + + encountered_non_splat_arg_def_instance |= !encountered_splat_arg + encountered_non_double_splat_arg_def_instance |= !encountered_double_splat_arg + + if @type_blocks && (arg = def_instance.block_arg) + all_typed_args[arg.external_name] << resolve_type(arg) + end + end + + # If a given collection of def_instances has a splat defined AND at least one def_instance didn't have a type for it, + # then we can't add types to the signature. + # https://crystal-lang.org/reference/1.14/syntax_and_semantics/type_restrictions.html#splat-type-restrictions + if @type_splats && (splat_arg = splat_arg_name) && encountered_non_splat_arg_def_instance + @warnings << "Not adding type restriction for splat #{splat_arg}, found empty splat call: #{parsed.location}" + all_typed_args.delete(splat_arg) + end + if @type_double_splats && (named_arg = named_arg_name) && encountered_non_double_splat_arg_def_instance + @warnings << "Not adding type restriction for double splat #{named_arg}, found empty deouble splat call: #{parsed.location}" + all_typed_args.delete(named_arg) + end + + # Convert each set of types into a single ASTNode (for easier printing) representing those types + all_args = all_typed_args.compact_map do |name, type_set| + compacted_types = filter_no_return(type_set) + + {name, to_ast(compacted_types)} + end.to_h + + # Similar idea for return_type to get into an easier to print state + returns = filter_no_return(def_instances.compact_map do |inst| + resolve_type(inst) + end.uniq!) + + return_type = to_ast(returns) + + {parsed.location.to_s, Signature.new( + name: parsed.name, + return_type: return_type, + location: parsed.location.not_nil!, + args: all_args + )} + end.to_h + end + + # Given a list of types, wrap them in a ASTNode appropriate for printing that type out + private def to_ast(types : Array(Crystal::Type)) + case types.size + when 1 + # Use var to communicate a single type name + Crystal::Var.new(types[0].to_s) + when 2 + if types.includes?(program.nil) + # One type is Nil, so write this using the slightly more human readable format with a '?' suffix + not_nil_type = types.reject(&.==(program.nil))[0] + Crystal::Var.new("#{not_nil_type}?") + else + Crystal::Union.new(types.map { |t| Crystal::Var.new(t.to_s).as(Crystal::ASTNode) }) + end + else + Crystal::Union.new(types.map { |t| Crystal::Var.new(t.to_s).as(Crystal::ASTNode) }) + end + end + + # Child class of the crystal formatter, but will write in type restrictions for the def return_type, or individual args, + # if there's a signature for a given def and those type restrictions are missing. + # + # All methods present are copy / paste from the original Crystal::Formatter for the given `visit` methods + class SourceTyperFormatter < Crystal::Formatter + @current_def : Crystal::Def? = nil + getter? added_types = false + + def initialize(source : String, @signatures : Hash(String, Signature)) + # source = File.read(filename) + super(source) + end + + def visit(node : Crystal::Def) + @implicit_exception_handler_indent = @indent + @inside_def += 1 + @vars.push Set(String).new + + write_keyword :abstract, " " if node.abstract? + + write_keyword :def, " ", skip_space_or_newline: false + + if receiver = node.receiver + skip_space_or_newline + accept receiver + skip_space_or_newline + write_token :OP_PERIOD + end + + @lexer.wants_def_or_macro_name do + skip_space_or_newline + end + + write node.name + + indent do + next_token + + # this formats `def foo # ...` to `def foo(&) # ...` for yielding + # methods before consuming the comment line + if node.block_arity && node.args.empty? && !node.block_arg && !node.double_splat + write "(&)" + end + + skip_space consume_newline: false + next_token_skip_space if @token.type.op_eq? + end + + # ===== BEGIN NEW CODE ===== + # Wrap the format_def_args call with a quick-to-reach reference to the current def (for signature lookup) + @current_def = node + to_skip = format_def_args node + @current_def = nil + # ===== END NEW CODE ===== + + if return_type = node.return_type + skip_space + write_token " ", :OP_COLON, " " + skip_space_or_newline + accept return_type + # ===== BEGIN NEW CODE ===== + # If the def doesn't already have a type restriction and we have a signature for this method, write in the return_type + elsif (sig = @signatures[node.location.to_s]?) && sig.name != "initialize" + skip_space + write " : #{sig.return_type}" + @added_types = true + # ===== END NEW CODE ===== + end + + if free_vars = node.free_vars + skip_space_or_newline + write " forall " + next_token + last_index = free_vars.size - 1 + free_vars.each_with_index do |free_var, i| + skip_space_or_newline + check :CONST + write free_var + next_token + skip_space_or_newline if last_index != i + if @token.type.op_comma? + write ", " + next_token_skip_space_or_newline + end + end + end + + body = remove_to_skip node, to_skip + + unless node.abstract? + format_nested_with_end body + end + + @vars.pop + @inside_def -= 1 + + false + end + + def visit(node : Crystal::Arg) + @last_arg_is_skip = false + + restriction = node.restriction + default_value = node.default_value + + if @inside_lib > 0 + # This is the case of `fun foo(Char)` + if !@token.type.ident? && restriction + accept restriction + return false + end + end + + if node.name.empty? + skip_space_or_newline + else + @vars.last.add(node.name) + + at_skip = at_skip? + + if !at_skip && node.external_name != node.name + if node.external_name.empty? + write "_" + elsif @token.type.delimiter_start? + accept Crystal::StringLiteral.new(node.external_name) + else + write @token.value + end + write " " + next_token_skip_space_or_newline + end + + @last_arg_is_skip = at_skip? + + write @token.value + next_token + end + + if restriction + skip_space_or_newline + write_token " ", :OP_COLON, " " + skip_space_or_newline + accept restriction + # ===== BEGIN NEW CODE ===== + # If the current arg doesn't have a restriction already and we have a signature, write in the type restriction + elsif (sig = @signatures[@current_def.try &.location.to_s || 0_u64]?) && sig.args[node.name]? + skip_space_or_newline + write " : #{sig.args[node.name]}" + @added_types = true + # ===== END NEW CODE ===== + end + + if default_value + # The default value might be a Proc with args, so + # we need to remember this and restore it later + old_last_arg_is_skip = @last_arg_is_skip + + skip_space_or_newline + + check_align = check_assign_length node + write_token " ", :OP_EQ, " " + before_column = @column + skip_space_or_newline + accept default_value + check_assign_align before_column, default_value if check_align + + @last_arg_is_skip = old_last_arg_is_skip + end + + # This is the case of an enum member + if @token.type.op_semicolon? + next_token + @lexer.skip_space + if @token.type.comment? + write_comment + @exp_needs_indent = true + else + write ";" if @token.type.const? + write " " + @exp_needs_indent = @token.type.newline? + end + end + + false + end + end + + # A visitor for defs, oddly enough. + # + # Walk through the AST and capture all references to Defs that match a def_locator + class DefVisitor < Crystal::Visitor + getter all_defs = Array(Crystal::Def).new + getter files = Set(String).new + + CRYSTAL_LOCATOR_PARSER = /^.*\.cr(:(?\d+))?(:(?\d+))?$/ + + @dir_locators : Array(String) + @file_locators : Array(String) = [] of String + @line_locators : Array(String) = [] of String + @line_and_column_locators : Array(String) = [] of String + + def initialize(@def_locators : Array(String), entrypoint) + if @def_locators.empty? + entrypoint_dir = File.dirname(entrypoint) + # Nothing was provided, is the entrypoint in the `src` directory? + if entrypoint_dir.ends_with?("/src") || entrypoint_dir.includes?("/src/") + @def_locators << File.dirname(entrypoint_dir) + else + # entrypoint isn't in a 'src' directory, assume we should only type it, and only it, wherever it is + @def_locators << entrypoint + end + end + + def_locs = @def_locators.map { |p| File.expand_path(p) } + @dir_locators = def_locs.reject(&.match(CRYSTAL_LOCATOR_PARSER)) + def_locs.compact_map(&.match(CRYSTAL_LOCATOR_PARSER)).each do |loc| + @file_locators << loc[0] unless loc["line_number"]? + @line_locators << loc[0] unless loc["col_number"]? + @line_and_column_locators << loc[0] if loc["line_number"]? && loc["col_number"]? + end + end + + def visit(node : Crystal::Def) + return false unless loc = node.location + return false unless loc.filename && loc.line_number && loc.column_number + if node_in_def_locators(loc) + all_defs << node + files << loc.filename.to_s + end + + false + end + + def visit(node : Crystal::ASTNode) + true + end + + private def node_in_def_locators(location : Crystal::Location) : Bool + return false unless location.to_s.starts_with?("/") || location.to_s.starts_with?(/\w:/) + return true if @dir_locators.any? { |d| location.filename.to_s.starts_with?(d) } + return true if @file_locators.includes?(location.filename) + return true if @line_locators.includes?("#{location.filename}:#{location.line_number}") + @line_and_column_locators.includes?("#{location.filename}:#{location.line_number}:#{location.column_number}") + end + end + end +end