Merge pull request #6 from retrage/use-async-openai
build: Use async-openai crate
retrage authored Sep 24, 2023
2 parents 7821c91 + b85e107 commit 8667c41
Showing 9 changed files with 126 additions and 385 deletions.
8 changes: 2 additions & 6 deletions Cargo.toml
@@ -12,7 +12,6 @@ path = "tests/tests.rs"
 
 [features]
 default = []
-davinci = []
 
 [dev-dependencies]
 trybuild = { version = "1.0", features = ["diff"] }
@@ -21,8 +20,5 @@ trybuild = { version = "1.0", features = ["diff"] }
 syn = { version = "2.0", features = ["full", "extra-traits", "parsing"] }
 proc-macro2 = { version = "1.0", features = ["nightly"] }
 quote = "1.0"
-hyper = { version = "0.14", features = ["full"] }
-hyper-tls = "0.5"
-tokio = { version = "1.0", features = ["full"] }
-serde_json = "1.0"
-serde = { version = "1.0", features = ["derive"] }
+async-openai = "0.14.2"
+tokio = { version = "1.0", features = ["rt-multi-thread"] }
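
The swap above replaces the hand-rolled HTTP stack (hyper, hyper-tls, serde, serde_json) with the async-openai client. Because async-openai is async-only while procedural macros must run synchronously, tokio stays as a dependency, trimmed to the rt-multi-thread feature so the macros can drive requests with block_on. A minimal sketch of the pattern this dependency pair enables, using only API calls that appear elsewhere in this diff (the helper function itself is illustrative, not part of the commit):

```rust
use async_openai::{
    types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, Role},
    Client,
};
use tokio::runtime::Runtime;

// Illustrative helper: send one chat request synchronously, the way the
// proc macros in this commit drive async-openai from non-async code.
fn ask(prompt: &str) -> Result<(), Box<dyn std::error::Error>> {
    let request = CreateChatCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .messages([ChatCompletionRequestMessageArgs::default()
            .role(Role::User)
            .content(prompt)
            .build()?])
        .build()?;

    let rt = Runtime::new()?; // proc macros are sync, so block on the future
    let client = Client::new(); // reads OPENAI_API_KEY from the environment
    let _response = rt.block_on(client.chat().create(request))?;
    Ok(())
}
```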
5 changes: 0 additions & 5 deletions README.md
@@ -86,11 +86,6 @@ fn div_u32(a: u32, b: u32) -> u32 {
 }
 ```
 
-## Supported Models
-
-* ChatGPT: `gpt-3.5-turbo` (default)
-* Text Completion: `text-davinci-003` (Specify `davinci` feature to enable it)
-
 ## License
 
 gpt-macro is released under the MIT license.
5 changes: 1 addition & 4 deletions src/internal.rs
@@ -1,9 +1,6 @@
 // SPDX-License-Identifier: MIT
 // Akira Moroo <[email protected]> 2023
 
-mod chatgpt;
-mod completion;
-mod text_completion;
-
 pub mod auto_impl;
 pub mod auto_test;
+mod utils;
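
The new `utils` module (its diff is among the files not loaded on this page) supplies the `extract_code` helper that both macros below call on the chat response. Judging from the prompts, which instruct the model to answer with a fenced Rust code block, it presumably strips the Markdown fences and returns the code inside. A rough sketch under that assumption; the real src/internal/utils.rs may differ, and the exact response field types vary across async-openai versions:

```rust
use async_openai::types::CreateChatCompletionResponse;

// Assumed shape of the helper (the actual file is not shown in this diff):
// take the first choice's message and return the body of its fenced block.
pub fn extract_code(
    response: &CreateChatCompletionResponse,
) -> Result<String, Box<dyn std::error::Error>> {
    let content = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.clone())
        .ok_or("empty chat completion response")?;
    let start = content.find("```rust").ok_or("no ```rust block found")? + "```rust".len();
    let len = content[start..]
        .find("```")
        .ok_or("unterminated code block")?;
    Ok(content[start..start + len].trim().to_string())
}
```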
72 changes: 41 additions & 31 deletions src/internal/auto_impl.rs
@@ -1,55 +1,69 @@
 // SPDX-License-Identifier: MIT
 // Akira Moroo <[email protected]> 2023
 
+use async_openai::{
+    types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, Role},
+    Client,
+};
 use proc_macro::TokenStream;
 use quote::quote;
 use syn::{
     parse::{Parse, ParseStream},
     parse_macro_input, parse_str, LitStr,
 };
+use tokio::runtime::Runtime;
 
-use crate::internal::completion::CodeCompletion;
+use super::utils;
 
 /// Parses the following syntax:
 ///
 /// auto_impl! {
 ///     $STR_LIT
 ///     $TOKEN_STREAM
 /// }
-struct AutoImpl<C: CodeCompletion> {
+struct AutoImpl {
     doc: String,
     token_stream: proc_macro2::TokenStream,
-    code_completion: C,
 }
 
-impl<C: CodeCompletion> Parse for AutoImpl<C> {
+impl Parse for AutoImpl {
     fn parse(input: ParseStream) -> syn::Result<Self> {
         let doc = input.parse::<LitStr>()?.value();
         let token_stream = input.parse::<proc_macro2::TokenStream>()?;
-        Ok(AutoImpl {
-            doc,
-            token_stream,
-            code_completion: C::new(),
-        })
+        Ok(AutoImpl { doc, token_stream })
     }
 }
 
-impl<C: CodeCompletion> AutoImpl<C> {
-    fn completion(&mut self) -> Result<TokenStream, Box<dyn std::error::Error>> {
-        let init_prompt = "You are a Rust expert who can implement the given function.";
-        self.code_completion.init(init_prompt.to_string());
-        self.code_completion.add_context(format!(
-            "Read this incomplete Rust code:\n```rust\n{}\n```",
-            self.token_stream
-        ));
-        self.code_completion.add_context(format!(
-            "Complete the Rust code that follows this instruction: '{}'. Your response must start with code block '```rust'.",
-            self.doc
-        ));
+impl AutoImpl {
+    async fn completion(&mut self) -> Result<TokenStream, Box<dyn std::error::Error>> {
+        let request = CreateChatCompletionRequestArgs::default()
+            .model("gpt-3.5-turbo")
+            .messages([
+                ChatCompletionRequestMessageArgs::default()
+                    .role(Role::System)
+                    .content("You are a Rust expert who can implement the given function.")
+                    .build()?,
+                ChatCompletionRequestMessageArgs::default()
+                    .role(Role::User)
+                    .content(format!(
+                        "Read this incomplete Rust code:\n```rust\n{}\n```",
+                        self.token_stream
+                    ))
+                    .build()?,
+                ChatCompletionRequestMessageArgs::default()
+                    .role(Role::User)
+                    .content(format!(
+                        "Complete the Rust code that follows this instruction: '{}'. Your response must start with code block '```rust'.",
+                        self.doc
+                    ))
+                    .build()?,
+            ])
+            .build()?;
 
-        let code_text = self.code_completion.code_completion()?;
+        let client = Client::new();
+        let response = client.chat().create(request).await?;
 
-        self.parse_str(&code_text)
+        self.parse_str(&utils::extract_code(&response)?)
     }
 
     fn parse_str(&self, s: &str) -> Result<TokenStream, Box<dyn std::error::Error>> {
@@ -66,13 +80,9 @@ impl<C: CodeCompletion> AutoImpl<C> {
 }
 
 pub fn auto_impl_impl(input: TokenStream) -> TokenStream {
-    #[cfg(not(feature = "davinci"))]
-    type Backend = crate::internal::chatgpt::ChatGPT;
-
-    #[cfg(feature = "davinci")]
-    type Backend = crate::internal::text_completion::TextCompletion;
-
-    let mut auto_impl = parse_macro_input!(input as AutoImpl<Backend>);
+    let mut auto_impl = parse_macro_input!(input as AutoImpl);
 
-    auto_impl.completion().unwrap_or_else(|e| panic!("{}", e))
+    let rt = Runtime::new().expect("Failed to create a runtime.");
+    rt.block_on(auto_impl.completion())
+        .unwrap_or_else(|e| panic!("{}", e))
 }
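
For reference, `AutoImpl` parses a string literal followed by arbitrary tokens, so a call site looks roughly like the following (illustrative, adapted from the style of the README's `div_u32` example). Note that expansion now performs a network call at compile time, so OPENAI_API_KEY must be set when building:

```rust
use gpt_macro::auto_impl;

auto_impl! {
    "Return the sum of all even numbers in the slice."
    fn sum_even(values: &[u32]) -> u32;
}
```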
87 changes: 54 additions & 33 deletions src/internal/auto_test.rs
@@ -1,15 +1,20 @@
 // SPDX-License-Identifier: MIT
 // Akira Moroo <[email protected]> 2023
 
+use async_openai::{
+    types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, Role},
+    Client,
+};
 use proc_macro::TokenStream;
 use quote::{quote, ToTokens};
 use std::collections::HashSet;
 use syn::{
     parse::{Parse, ParseStream},
     parse_macro_input, parse_str, Ident, Token,
 };
+use tokio::runtime::Runtime;
 
-use crate::internal::completion::CodeCompletion;
+use super::utils;
 
 /// Parses a list of test function names separated by commas.
 ///
@@ -29,48 +34,68 @@ impl Parse for Args {
     }
 }
 
-struct AutoTest<C: CodeCompletion> {
+struct AutoTest {
     token_stream: proc_macro2::TokenStream,
-    code_completion: C,
 }
 
-impl<C: CodeCompletion> AutoTest<C> {
-    pub fn new(token_stream: proc_macro2::TokenStream) -> Self {
-        Self {
-            token_stream,
-            code_completion: C::new(),
-        }
+impl AutoTest {
+    fn new(token_stream: proc_macro2::TokenStream) -> Self {
+        Self { token_stream }
    }
 
-    pub fn completion(&mut self, args: Args) -> Result<TokenStream, Box<dyn std::error::Error>> {
+    async fn completion(&mut self, args: Args) -> Result<TokenStream, Box<dyn std::error::Error>> {
         let mut output = self.token_stream.clone();
 
-        let init_prompt =
-            "You are a Rust expert who can generate perfect tests for the given function.";
-        self.code_completion.init(init_prompt.to_string());
-        self.code_completion.add_context(format!(
-            "Read this Rust function:\n```rust\n{}\n```",
-            self.token_stream,
-        ));
+        let mut messages = vec![
+            ChatCompletionRequestMessageArgs::default()
+                .role(Role::System)
+                .content(
+                    "You are a Rust expert who can generate perfect tests for the given function.",
+                )
+                .build()?,
+            ChatCompletionRequestMessageArgs::default()
+                .role(Role::User)
+                .content(format!(
+                    "Read this Rust function:\n```rust\n{}\n```",
+                    self.token_stream
+                ))
+                .build()?,
+        ];
 
         if args.test_names.is_empty() {
-            self.code_completion.add_context(
-                "Write a test case for the function as much as possible in Markdown code snippet style. Your response must start with code block '```rust'.".to_string()
+            messages.push(
+                ChatCompletionRequestMessageArgs::default()
+                    .role(Role::User)
+                    .content(
+                        "Write a test case for the function as much as possible in Markdown code snippet style. Your response must start with code block '```rust'.",
+                    )
+                    .build()?,
             );
         } else {
             for test_name in args.test_names {
-                self.code_completion.add_context(
-                    format!(
-                        "Write a test case `{}` for the function in Markdown code snippet style. Your response must start with code block '```rust'.",
-                        test_name
-                    )
+                messages.push(
+                    ChatCompletionRequestMessageArgs::default()
+                        .role(Role::User)
+                        .content(
+                            format!(
+                                "Write a test case `{}` for the function in Markdown code snippet style. Your response must start with code block '```rust'.",
                                test_name
+                            )
+                        )
+                        .build()?,
                 );
             }
         }
 
-        let test_text = self.code_completion.code_completion()?;
+        let request = CreateChatCompletionRequestArgs::default()
+            .model("gpt-3.5-turbo")
+            .messages(messages)
+            .build()?;
+
+        let client = Client::new();
+        let response = client.chat().create(request).await?;
 
-        let test_case = self.parse_str(&test_text)?;
+        let test_case = self.parse_str(&utils::extract_code(&response)?)?;
         test_case.to_tokens(&mut output);
 
         Ok(TokenStream::from(output))
@@ -93,13 +118,9 @@ pub fn auto_test_impl(args: TokenStream, input: TokenStream) -> TokenStream {
     // Parse the list of test function names that should be generated.
     let args = parse_macro_input!(args as Args);
 
-    #[cfg(not(feature = "davinci"))]
-    type Backend = crate::internal::chatgpt::ChatGPT;
-
-    #[cfg(feature = "davinci")]
-    type Backend = crate::internal::text_completion::TextCompletion;
+    let mut auto_test = AutoTest::new(input.into());
 
-    AutoTest::<Backend>::new(input.into())
-        .completion(args)
+    let rt = Runtime::new().expect("Failed to create a runtime.");
+    rt.block_on(auto_test.completion(args))
         .unwrap_or_else(|e| panic!("{}", e))
 }
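
`auto_test` is the attribute-macro counterpart: its arguments are the comma-separated test names parsed by `Args`, and with no arguments the model chooses the test cases itself. An illustrative call site, adapted from the README's `div_u32` example:

```rust
use gpt_macro::auto_test;

#[auto_test(test_valid_div, test_div_by_zero)]
fn div_u32(a: u32, b: u32) -> u32 {
    a / b
}
```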
[The diffs for the remaining 4 changed files did not load. Per the src/internal.rs changes above, these are most likely src/internal/chatgpt.rs, src/internal/completion.rs, and src/internal/text_completion.rs (removed) and src/internal/utils.rs (added).]
