diff --git a/analysis_options.yaml b/analysis_options.yaml index fa798a8..bb1d822 100644 --- a/analysis_options.yaml +++ b/analysis_options.yaml @@ -2,3 +2,4 @@ include: package:very_good_analysis/analysis_options.5.1.0.yaml linter: rules: public_member_api_docs: false + sort_constructors_first: false diff --git a/lib/src/command_runner.dart b/lib/src/command_runner.dart index 27ae2d5..3691b84 100644 --- a/lib/src/command_runner.dart +++ b/lib/src/command_runner.dart @@ -1,6 +1,7 @@ import 'dart:io'; import 'package:ai_commit/src/commands/commands.dart'; +import 'package:ai_commit/src/utils/utils.dart'; import 'package:ai_commit/src/version.dart'; import 'package:args/args.dart'; import 'package:args/command_runner.dart'; @@ -11,7 +12,7 @@ import 'package:pub_updater/pub_updater.dart'; const executableName = 'ai_commit'; const packageName = 'ai_commit'; -const description = 'A Very Good Project created by Very Good CLI.'; +const description = 'Dart CLI for generate commit messages with OpenAI.'; /// {@template ai_commit_command_runner} /// A [CommandRunner] for the CLI. @@ -36,18 +37,21 @@ class AiCommitCommandRunner extends CompletionCommandRunner { negatable: false, help: 'Automatically stage changes in tracked files for the commit', ) - ..addFlag( + ..addOption( 'exclude', abbr: 'x', - negatable: false, help: 'Files to exclude from AI analysis', ) + ..addOption( + 'count', + abbr: 'c', + help: ''' +Count of messages to generate (Warning: generating multiple costs more)''', + ) ..addFlag( - 'generate', - abbr: 'g', - negatable: false, - help: - 'Number of messages to generate (Warning: generating multiple costs more)', + 'conventional', + help: ''' +Format the commit message according to the Conventional Commits specification.''', ) ..addFlag( 'version', @@ -73,6 +77,12 @@ class AiCommitCommandRunner extends CompletionCommandRunner { @override Future run(Iterable args) async { + final gitRepoPath = assetGitRepo(); + if (gitRepoPath.isEmpty) { + _logger.info('The current directory must be a git repository.'); + return ExitCode.software.code; + } + String? home; final envVars = Platform.environment; @@ -146,8 +156,18 @@ class AiCommitCommandRunner extends CompletionCommandRunner { if (topLevelResults['version'] == true) { _logger.info(packageVersion); exitCode = ExitCode.success.code; - } else { + } else if (topLevelResults.command?.name == ConfigCommand.commandName || + topLevelResults['help'] == true) { exitCode = await super.runCommand(topLevelResults); + } else { + exitCode = await _startWork( + all: topLevelResults.wasParsed('all') ? topLevelResults['all'] : null, + count: topLevelResults['count'], + exclude: topLevelResults['exclude'], + conventional: topLevelResults.wasParsed('conventional') + ? topLevelResults['conventional'] + : null, + ); } // Check for updates @@ -166,12 +186,158 @@ class AiCommitCommandRunner extends CompletionCommandRunner { final latestVersion = await _pubUpdater.getLatestVersion(packageName); final isUpToDate = packageVersion == latestVersion; if (!isUpToDate) { - _logger..info('')..info( - ''' + _logger + ..info('') + ..info( + ''' ${lightYellow.wrap('Update available!')} ${lightCyan.wrap(packageVersion)} \u2192 ${lightCyan.wrap(latestVersion)} Run ${lightCyan.wrap('$executableName update')} to update''', - ); + ); } } catch (_) {} } + + Future _startWork({ + required dynamic all, + required dynamic count, + required dynamic exclude, + required dynamic conventional, + }) async { + try { + final apiKey = await getKey(); + if (apiKey == null) { + _logger.info( + '''No API key found. To generate one, run ${lightCyan.wrap('$executableName config')}''', + ); + return ExitCode.software.code; + } + + if (all == true) await Process.run('git', ['add', '--update']); + + int? msgCount; + + if (count != null) { + final value = int.tryParse('$count'); + if (value == null) { + _logger.err('Count must be an integer.'); + + return ExitCode.software.code; + } + if (value < 0) { + _logger.err('Count must be an greater than 0.'); + return ExitCode.software.code; + } + + if (value > 5) { + _logger.err('Count must be less than or equal to 5.'); + return ExitCode.software.code; + } + + msgCount = value; + } + + var excludeFiles = []; + + if (exclude != null) { + excludeFiles = [ + for (final e in exclude.toString().split(',')) e.trim() + ]; + } + + final detectingFiles = _logger.progress('Detecting staged files'); + + final staged = await getStagedDiff(excludeFiles: excludeFiles); + + if (staged.isEmpty) { + detectingFiles.complete('Detecting staged files'); + _logger.info( + ''' +No staged changes found. Stage your changes manually, or automatically stage all changes with the `--all` options.''', + ); + return ExitCode.success.code; + } + + final files = staged['files'] as List? ?? []; + + var message = getDetectedMessage(files: files); + + detectingFiles.complete( + '$message:\n${files.map((e) => ' $e').join('\n')}', + ); + + final s = _logger.progress('The AI is analyzing your changes'); + + var messages = []; + + final locale = await getLocale(); + final maxLength = await getMaxLength(); + + var completions = 1; + + if (msgCount != null) { + completions = msgCount; + } else { + completions = await getCount(); + } + + var isConventional = false; + + if (conventional != null) { + isConventional = conventional as bool; + } else { + isConventional = await getConventional(); + } + + messages = await generateCommitMessage( + apiKey: apiKey, + locale: locale, + logger: _logger, + maxLength: maxLength, + completions: completions, + diff: staged['diff'] as String, + isConventional: isConventional, + ); + + s.fail('Changes analyzed'); + + if (messages.isEmpty) { + _logger.info('No commit messages were generated. Try again.'); + return ExitCode.success.code; + } + + if (messages.length == 1) { + message = messages.first; + + final confirmed = + _logger.confirm('Use this commit message?\n\n $message\n'); + + if (!confirmed) { + _logger.info('Commit message canceled.'); + return ExitCode.software.code; + } + } else { + final selected = _logger.chooseOne( + 'Pick a commit message to use: ', + choices: messages, + ); + + if (selected.isEmpty) { + _logger.info('Commit message canceled.'); + return ExitCode.software.code; + } + + message = selected; + } + + await Process.run('git', ['commit', '-m', message]); + + s.complete('Successfully committed'); + + return ExitCode.success.code; + } catch (e) { + print(e.runtimeType); + print(e); + exit(0); + } + } } diff --git a/lib/src/commands/config_command.dart b/lib/src/commands/config_command.dart index bb045fc..1978ced 100644 --- a/lib/src/commands/config_command.dart +++ b/lib/src/commands/config_command.dart @@ -24,18 +24,10 @@ class ConfigCommand extends Command { help: ''' Format the commit message according to the Conventional Commits specification.''', ) - ..addOption( - 'proxy', - help: 'Set proxy server for OpenAI API.', - ) ..addOption( 'model', help: 'Set model name for OpenAI API.', ) - ..addOption( - 'timeout', - help: 'Set timeout in milliseconds.', - ) ..addOption( 'max-length', help: 'Set max length of commit message.', @@ -45,6 +37,8 @@ Format the commit message according to the Conventional Commits specification.'' @override String get description => 'ai_commit configuration'; + static const String commandName = 'config'; + @override String get name => 'config'; @@ -100,24 +94,25 @@ Format the commit message according to the Conventional Commits specification.'' final count = argResults?['count']; if (count != null) { - if (int.tryParse('$count') != null) { - final value = int.parse('$count'); - if (value < 0) { - _logger.err('Count must be an greater than 0.'); - return ExitCode.software.code; - } - - if (value > 5) { - _logger.err('Count must be less than or equal to 5.'); - return ExitCode.software.code; - } - - _logger.success('Setting "count" to "$value".'); - await setCount(value); - } else { + final value = int.tryParse('$count'); + + if (value == null) { _logger.err('Count must be an integer.'); return ExitCode.software.code; } + + if (value < 0) { + _logger.err('Count must be an greater than 0.'); + return ExitCode.software.code; + } + + if (value > 5) { + _logger.err('Count must be less than or equal to 5.'); + return ExitCode.software.code; + } + + _logger.success('Setting "count" to "$count".'); + await setCount(value); } // get `conventional` value from args and store @@ -132,52 +127,19 @@ Format the commit message according to the Conventional Commits specification.'' } } - // get `proxy` value from args - // check valid URL and store - - final proxy = argResults?['proxy']; - - if (proxy != null) { - if (Uri.tryParse('$proxy')?.isAbsolute ?? false) { - _logger.success('Setting "proxy" to "$proxy".'); - await setProxy('$proxy'); - } else { - _logger.err('Proxy must be valid URL.'); - return ExitCode.software.code; - } - } - // get `model` value from args // check isNotEmpty and store final model = argResults?['model']; if (model != null) { - if (model is String && model.isNotEmpty) { - _logger.success('Setting "model" to "$model".'); - await setModel(model); - } else { + if ('$model'.isEmpty) { _logger.err('Model must be a string.'); return ExitCode.software.code; } - } - // get `timeout` value from args - // check greater than 500ms and store - - final timeout = argResults?['timeout']; - - if (timeout != null) { - if (int.tryParse('$timeout') != null) { - final value = int.parse('$timeout'); - if (value > 500) { - _logger.success('Setting "timeout" to "$value".'); - await setProxy('$value'); - } else { - _logger.err('Timeout must be an greater than 500ms.'); - return ExitCode.software.code; - } - } + _logger.success('Setting "model" to "$model".'); + await setModel('$model'); } // get `max-length` value from args @@ -186,19 +148,20 @@ Format the commit message according to the Conventional Commits specification.'' final maxLength = argResults?['max-length']; if (maxLength != null) { - if (int.tryParse('$maxLength') != null) { - final value = int.parse('$maxLength'); - if (value > 20) { - _logger.success('Setting "max-length" to "$value".'); - await setProxy('$value'); - } else { - _logger.err('Max length must be an greater than 20.'); - return ExitCode.software.code; - } - } else { + final value = int.tryParse('$maxLength'); + + if (value == null) { _logger.err('Max length must be an integer.'); return ExitCode.software.code; } + + if (value < 20) { + _logger.err('Max length must be an greater than 20.'); + return ExitCode.software.code; + } + + _logger.success('Setting "max-length" to "$value".'); + await setMaxLength(value); } return ExitCode.software.code; diff --git a/lib/src/utils/config.dart b/lib/src/utils/config.dart index 4182554..1f684a3 100644 --- a/lib/src/utils/config.dart +++ b/lib/src/utils/config.dart @@ -1,34 +1,62 @@ import 'package:hive/hive.dart'; +String defaultModel = 'gpt-3.5-turbo-1106'; +int defaultMaxLength = 20; + // store data in hive box Future setData(String key, dynamic value) async { final box = await Hive.openBox('ai_commit_config'); await box.put(key, value); } +// get data from hive box +Future getData(String key, [dynamic defaultValue]) async { + final box = await Hive.openBox('ai_commit_config'); + return box.get(key) ?? defaultValue; +} + // store openai api key Future setKey(String value) async { await setData('api_key', value); } +// get openai api key +Future getKey() async { + final value = await getData('api_key'); + return value as String?; +} + // store locale language Future setLocale(String value) async { await setData('locale', value); } +// get locale language +Future getLocale() async { + final value = await getData('locale'); + return value as String?; +} + // store generate commit message count Future setCount(int value) async { await setData('generate_count', value); } +// get generate commit message count +Future getCount() async { + final value = await getData('generate_count',1); + return value as int; +} + // store is_conventional value Future setConventional({required bool value}) async { await setData('is_conventional', value); } -// store proxy server address -Future setProxy(String value) async { - await setData('proxy', value); +// get is_conventional value +Future getConventional() async { + final value = await getData('is_conventional', false); + return value as bool; } // store model name @@ -36,13 +64,19 @@ Future setModel(String value) async { await setData('model', value); } -// store timeout value -Future setTimeout(int value) async { - await setData('timeout', value); +// get model name +Future getModel() async { + final value = await getData('model', defaultModel); + return value as String; } -void getConfig() { - final config = Hive.box>('ai_commit_config'); +// store max length +Future setMaxLength(int value) async { + await setData('max_length', value); +} - for (var key in config.keys) {} +// get max length +Future getMaxLength() async { + final value = await getData('max_length', defaultMaxLength); + return value as int; } diff --git a/lib/src/utils/error.dart b/lib/src/utils/error.dart new file mode 100644 index 0000000..121c9a4 --- /dev/null +++ b/lib/src/utils/error.dart @@ -0,0 +1,20 @@ +import 'package:ai_commit/src/command_runner.dart'; +import 'package:ai_commit/src/version.dart'; +import 'package:mason_logger/mason_logger.dart'; + +class KnownError extends Error { + final Object? message; + + KnownError([this.message]); + + @override + String toString() { + Logger() + ..err(message.toString()) + ..err('\n$packageName v$packageVersion') + ..err('\nPlease open a Bug report with the information above:') + ..err('https://github.com/thitlwincoder/ai_commit/issues/new/choose'); + + return ''; + } +} diff --git a/lib/src/utils/git.dart b/lib/src/utils/git.dart new file mode 100644 index 0000000..e7d57d6 --- /dev/null +++ b/lib/src/utils/git.dart @@ -0,0 +1,61 @@ +import 'dart:io'; + +String assetGitRepo() { + final result = Process.runSync('git', ['rev-parse', '--show-toplevel']); + return result.stdout.toString(); +} + +String excludeFromDiff(String path) => ':(exclude)$path'; + +Future> getStagedDiff({ + List? excludeFiles, +}) async { + // read .gitignore file from project path + // remove comments and empty lines + + final gitignoreFiles = []; + + final gitignore = File('.gitignore'); + if (gitignore.existsSync()) { + gitignore.readAsLinesSync().forEach((line) { + if (!line.startsWith('#') && line.isNotEmpty) { + gitignoreFiles.add(excludeFromDiff(line)); + } + }); + } + + final diffCached = ['diff', '--cached', '--diff-algorithm=minimal']; + + var result = await Process.run( + 'git', + [ + ...diffCached, + '--name-only', + ...gitignoreFiles, + if (excludeFiles != null) ...excludeFiles.map(excludeFromDiff), + ], + ); + + final files = result.stdout.toString(); + + if (files.isEmpty) return {}; + + result = await Process.run( + 'git', + [ + ...diffCached, + ...gitignoreFiles, + if (excludeFiles != null) ...excludeFiles.map(excludeFromDiff), + ], + ); + + final diff = result.stdout.toString(); + + return {'files': files.split('\n'), 'diff': diff}; +} + +String getDetectedMessage({List? files}) { + files ??= []; + + return 'Detected ${files.length} staged file${files.length == 1 ? '' : 's'}'; +} diff --git a/lib/src/utils/openai.dart b/lib/src/utils/openai.dart new file mode 100644 index 0000000..04619b1 --- /dev/null +++ b/lib/src/utils/openai.dart @@ -0,0 +1,105 @@ +import 'package:ai_commit/src/utils/utils.dart'; +import 'package:dart_openai/dart_openai.dart'; +import 'package:http/http.dart' as http; +import 'package:mason_logger/mason_logger.dart'; + +Future createChatCompletion({ + required Logger logger, + required String apiKey, + required Map data, +}) async { + final uri = Uri.https( + 'api.openai.com', + '/v1/chat/completions', + {'authorization': 'Bearer $apiKey'}, + ); + + final r = await http.post(uri, body: data); + + final statusCode = r.statusCode; + + if (statusCode < 200 || statusCode > 299) { + var errorMessage = 'OpenAI API Error: $statusCode\n\n${r.body}'; + + if (statusCode == 500) { + errorMessage += '\n\nCheck the API status: https://status.openai.com'; + } + + logger.info(errorMessage); + return ''; + } + + return r.body; +} + +String sanitizeMessage(String message) { + return message.trim().replaceAll('[\n\r]', '').replaceAll(r'(\w)\.$', r'\$1'); +} + +List deduplicateMessages(List messages) { + return List.from(messages.toSet()); +} + +Future> generateCommitMessage({ + required String apiKey, + required String? locale, + required String diff, + required int completions, + required int maxLength, + required bool isConventional, + required Logger logger, + String? model, +}) async { + model ??= 'gpt-3.5-turbo-1106'; + try { + OpenAI.apiKey = apiKey; + final completion = await OpenAI.instance.chat.create( + topP: 1, + model: model, + maxTokens: 200, + n: completions, + temperature: .7, + presencePenalty: 0, + frequencyPenalty: 0, + messages: [ + OpenAIChatCompletionChoiceMessageModel( + role: OpenAIChatMessageRole.system, + content: [ + OpenAIChatCompletionChoiceMessageContentItemModel.text( + generatePrompt( + locale: locale, + maxLength: maxLength, + isConventional: isConventional, + ), + ), + ], + ), + OpenAIChatCompletionChoiceMessageModel( + role: OpenAIChatMessageRole.user, + content: [ + OpenAIChatCompletionChoiceMessageContentItemModel.text(diff), + ], + ), + ], + ); + + final contents = []; + + for (final choice in completion.choices) { + final content = choice.message.content; + if (content != null) contents.addAll(content); + } + + final messages = []; + + for (final content in contents) { + final text = content.text; + + if (text != null) messages.add(sanitizeMessage(text)); + } + + return deduplicateMessages(messages); + } catch (e) { + throw KnownError(e); + } +} diff --git a/lib/src/utils/prompt.dart b/lib/src/utils/prompt.dart new file mode 100644 index 0000000..55438d4 --- /dev/null +++ b/lib/src/utils/prompt.dart @@ -0,0 +1,43 @@ +import 'dart:convert'; + +String get conventionalPrompt { + var map = { + 'docs': 'Documentation only changes', + 'style': ''' +Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc)''', + 'refactor': 'A code change that neither fixes a bug nor adds a feature', + 'perf': 'A code change that improves performance', + 'test': 'Adding missing tests or correcting existing tests', + 'build': 'Changes that affect the build system or external dependencies', + 'ci': 'Changes to our CI configuration files and scripts', + 'chore': "Other changes that don't modify src or test files", + 'revert': 'Reverts a previous commit', + 'feat': 'A new feature', + 'fix': 'A bug fix', + }; + + return ''' +Choose a type from the type-to-description JSON below that best describes the git diff: +${jsonEncode(map)}'''; +} + +String specifyCommitFormat = + 'The output response must be in format:\n $conventionalPrompt'; + +String generatePrompt({ + required String? locale, + required int maxLength, + required bool isConventional, +}) { + final msgs = [ + ''' +Generate a concise git commit message written in present tense for the following code diff with the given specifications below:''', + if (locale != null) 'Message language: $locale', + 'Commit message must be a maximum of $maxLength characters.', + ''' +Exclude anything unnecessary such as translation. Your entire response will be passed directly into git commit.''', + if (isConventional) ...[conventionalPrompt, specifyCommitFormat], + ]; + + return msgs.join('\n'); +} diff --git a/lib/src/utils/utils.dart b/lib/src/utils/utils.dart index ee0ca3b..c778fbf 100644 --- a/lib/src/utils/utils.dart +++ b/lib/src/utils/utils.dart @@ -1,2 +1,7 @@ //GENERATED BARREL FILE -export 'config.dart'; +export 'config.dart'; +export 'error.dart'; +export 'git.dart'; +export 'openai.dart'; +export 'prompt.dart'; + diff --git a/pubspec.yaml b/pubspec.yaml index 9c5a94f..28cf360 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -9,7 +9,9 @@ environment: dependencies: args: ^2.4.2 cli_completion: ^0.4.0 + dart_openai: ^5.0.0 hive: ^2.2.3 + http: ^1.1.1 mason_logger: ^0.2.10 pub_updater: ^0.4.0