Refactor to make providing the model slug optional, relying on a default model per service.
felixarntz committed Sep 12, 2024
1 parent 9bfb2d5 commit 2173382
Showing 7 changed files with 103 additions and 76 deletions.
35 changes: 27 additions & 8 deletions includes/Google/Google_AI_Service.php
@@ -72,6 +72,17 @@ public function get_capabilities(): array {
return AI_Capabilities::get_model_class_capabilities( Google_AI_Model::class );
}

/**
* Gets the default model slug to use with the service when none is provided.
*
* @since n.e.x.t
*
* @return string The default model slug.
*/
public function get_default_model_slug(): string {
return 'gemini-1.5-flash';
}

/**
* Gets the API client instance.
*
@@ -122,14 +133,15 @@ static function ( array $model ) {
*
* @since n.e.x.t
*
* @param string $model The model slug.
* @param array<string, mixed> $model_params {
* Optional. Additional model parameters. Default empty array.
*
* @type array<string, mixed> $generation_config Optional. Model generation configuration
* options. Default empty array.
* @type string|Parts|Content $system_instruction Optional. The system instruction for the
* model. Default none.
* Optional. Model parameters. Default empty array.
*
* @type string $model The model slug. By default, the service's
* default model slug is used.
* @type array<string, mixed> $generation_config Model generation configuration options.
* Default empty array.
* @type string|Parts|Content $system_instruction The system instruction for the model.
* Default none.
* @type Safety_Setting[]|array<string, mixed>[] $safety_settings Optional. The safety settings for the
* model. Default empty array.
* }
@@ -139,7 +151,14 @@ static function ( array $model ) {
* @throws InvalidArgumentException Thrown if the model slug or parameters are invalid.
* @throws Generative_AI_Exception Thrown if getting the model fails.
*/
public function get_model( string $model, array $model_params = array(), array $request_options = array() ): Generative_AI_Model {
public function get_model( array $model_params = array(), array $request_options = array() ): Generative_AI_Model {
if ( isset( $model_params['model'] ) ) {
$model = $model_params['model'];
unset( $model_params['model'] );
} else {
$model = $this->get_default_model_slug();
}

return new Google_AI_Model( $this->api, $model, $model_params, $request_options );
}
}
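For orientation, a minimal usage sketch of the updated signature. The $service instance, the generation config values, and the 'gemini-1.5-pro' override slug are illustrative assumptions, not part of this diff.

// Uses the service's default model ('gemini-1.5-flash' here), since no 'model' key is passed.
$model = $service->get_model(
	array(
		'generation_config'  => array( 'temperature' => 0.2 ),
		'system_instruction' => 'Answer in one short sentence.',
	)
);

// Overrides the default by passing the slug inside $model_params instead of as a separate argument.
$model = $service->get_model(
	array(
		'model'             => 'gemini-1.5-pro',
		'generation_config' => array( 'temperature' => 0.2 ),
	)
);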
25 changes: 21 additions & 4 deletions includes/Services/Cache/Cached_AI_Service.php
@@ -63,6 +63,17 @@ public function get_capabilities(): array {
return $this->service->get_capabilities();
}

/**
* Gets the default model slug to use with the service when none is provided.
*
* @since n.e.x.t
*
* @return string The default model slug.
*/
public function get_default_model_slug(): string {
return $this->service->get_default_model_slug();
}

/**
* Lists the available generative model slugs.
*
@@ -85,15 +96,21 @@ public function list_models( array $request_options = array() ): array {
*
* @since n.e.x.t
*
* @param string $model The model slug.
* @param array<string, mixed> $model_params Optional. Additional model parameters. Default empty array.
* @param array<string, mixed> $model_params {
* Optional. Model parameters. Default empty array.
*
* @type string $model The model slug. By default, the service's default model slug
* is used.
* @type array<string, mixed> $generation_config Model generation configuration options. Default empty array.
* @type string|Parts|Content $system_instruction The system instruction for the model. Default none.
* }
* @param array<string, mixed> $request_options Optional. The request options. Default empty array.
* @return Generative_AI_Model The generative model.
*
* @throws InvalidArgumentException Thrown if the model slug or parameters are invalid.
* @throws Generative_AI_Exception Thrown if getting the model fails.
*/
public function get_model( string $model, array $model_params = array(), array $request_options = array() ): Generative_AI_Model {
return $this->service->get_model( $model, $model_params, $request_options );
public function get_model( array $model_params = array(), array $request_options = array() ): Generative_AI_Model {
return $this->service->get_model( $model_params, $request_options );
}
}
21 changes: 15 additions & 6 deletions includes/Services/Contracts/Generative_AI_Service.php
@@ -40,6 +40,15 @@ public function get_service_slug(): string;
*/
public function get_capabilities(): array;

/**
* Gets the default model slug to use with the service when none is provided.
*
* @since n.e.x.t
*
* @return string The default model slug.
*/
public function get_default_model_slug(): string;

/**
* Lists the available generative model slugs.
*
@@ -57,19 +66,19 @@ public function list_models( array $request_options = array() ): array;
*
* @since n.e.x.t
*
* @param string $model The model slug.
* @param array<string, mixed> $model_params {
* Optional. Additional model parameters. Default empty array.
* Optional. Model parameters. Default empty array.
*
* @type array<string, mixed> $generation_config Optional. Model generation configuration options. Default
* empty array.
* @type string|Parts|Content $system_instruction Optional. The system instruction for the model. Default none.
* @type string $model The model slug. By default, the service's default model slug
* is used.
* @type array<string, mixed> $generation_config Model generation configuration options. Default empty array.
* @type string|Parts|Content $system_instruction The system instruction for the model. Default none.
* }
* @param array<string, mixed> $request_options Optional. The request options. Default empty array.
* @return Generative_AI_Model The generative model.
*
* @throws InvalidArgumentException Thrown if the model slug or parameters are invalid.
* @throws Generative_AI_Exception Thrown if getting the model fails.
*/
public function get_model( string $model, array $model_params = array(), array $request_options = array() ): Generative_AI_Model;
public function get_model( array $model_params = array(), array $request_options = array() ): Generative_AI_Model;
}
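Because every implementation now resolves the slug the same way, calling code can stay service-agnostic. A small sketch under that assumption follows; the helper function is hypothetical and not part of the plugin.

// Resolves a model from any Generative_AI_Service, optionally overriding the slug.
function wpsp_example_resolve_model( Generative_AI_Service $service, ?string $slug = null ): Generative_AI_Model {
	$model_params = array();
	if ( null !== $slug ) {
		// The slug is now an override inside $model_params rather than a required first argument.
		$model_params['model'] = $slug;
	}
	// With no 'model' key, the service falls back to get_default_model_slug().
	return $service->get_model( $model_params );
}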
@@ -131,26 +131,9 @@ protected function handle_request( WP_REST_Request $request ): WP_REST_Response
);
}

$service = $this->services_api->get_available_service( $request['slug'] );

if ( isset( $request['model'] ) && '' !== $request['model'] ) {
$model = $request['model'];
} else {
// For now, we just use the first model available. TODO: Improve this later, e.g. by specifying a default.
try {
$model_slugs = $service->list_models();
$model = $model_slugs[0];
} catch ( Generative_AI_Exception $e ) {
throw REST_Exception::create(
'rest_cannot_determine_model',
esc_html__( 'Determining the model to use failed.', 'wp-starter-plugin' ),
500
);
}
}

$service = $this->services_api->get_available_service( $request['slug'] );
$model_params = $this->process_model_params( $request['model_params'] ?? array() );
$model = $this->get_model( $service, $model, $model_params );
$model = $this->get_model( $service, $model_params );

// Parse content data into one of the supported formats.
$content = $this->parse_content( $request['content'] );
@@ -190,15 +173,14 @@ protected function handle_request( WP_REST_Request $request ): WP_REST_Response
* @since n.e.x.t
*
* @param Generative_AI_Service $service The service instance to get the model from.
* @param string $model The model slug.
* @param array<string, mixed> $model_params The model parameters.
* @return Generative_AI_Model&With_Text_Generation The model.
*
* @throws REST_Exception Thrown when the model cannot be retrieved or invalid parameters are provided.
*/
protected function get_model( Generative_AI_Service $service, string $model, array $model_params ): Generative_AI_Model {
protected function get_model( Generative_AI_Service $service, array $model_params ): Generative_AI_Model {
try {
$model = $service->get_model( $model, $model_params );
$model = $service->get_model( $model_params );
} catch ( Generative_AI_Exception $e ) {
throw REST_Exception::create(
'rest_cannot_get_model',
@@ -279,14 +261,34 @@ protected function process_model_params( array $model_params ): array {
*/
protected function args(): array {
return array(
'model' => array(
'description' => __( 'Model slug.', 'wp-starter-plugin' ),
'type' => 'string',
),
'model_params' => array(
'description' => __( 'Model parameters.', 'wp-starter-plugin' ),
'type' => 'object',
'properties' => array(),
'properties' => array(
'model' => array(
'description' => __( 'Model slug.', 'wp-starter-plugin' ),
'type' => 'string',
),
'generation_config' => array(
'description' => __( 'Model generation configuration options.', 'wp-starter-plugin' ),
'type' => 'object',
'additionalProperties' => true,
),
'system_instruction' => array(
'description' => __( 'System instruction for the model.', 'wp-starter-plugin' ),
'type' => array( 'string', 'object', 'array' ),
'oneOf' => array(
array(
'description' => __( 'Prompt text as a string.', 'wp-starter-plugin' ),
'type' => 'string',
),
array_merge(
array( 'description' => __( 'Prompt content object.', 'wp-starter-plugin' ) ),
$this->get_content_schema( array( Content::ROLE_SYSTEM ) )
),
),
),
),
'additionalProperties' => true,
),
'content' => array(
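To make the updated schema concrete, here is a sketch of a text-generation request that carries the optional model slug inside model_params. The route string is a placeholder (the route's path is not shown above), and all parameter values are illustrative only.

$request = new WP_REST_Request( 'POST', '/example-namespace/v1/example-generate-text' );
$request->set_body_params(
	array(
		'content'      => 'Write a two-sentence product description.',
		'model_params' => array(
			// 'model' may be omitted entirely; the service then falls back to its default model slug.
			'model'              => 'gemini-1.5-flash',
			'generation_config'  => array( 'temperature' => 0.7 ),
			'system_instruction' => 'You are a concise marketing assistant.',
		),
	)
);
$response = rest_do_request( $request );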
15 changes: 5 additions & 10 deletions src/ai-store/chat.js
@@ -37,11 +37,10 @@ const actions = {
* @param {string} chatId Identifier to use for the chat.
* @param {Object} options Chat options.
* @param {string} options.service AI service to use.
* @param {string} options.model Model to use.
* @param {Object} options.modelParams Model parameters.
* @param {Object} options.modelParams Model parameters (including optional model slug).
* @return {Function} Action creator.
*/
startChat( chatId, { service, model, modelParams } ) {
startChat( chatId, { service, modelParams } ) {
return async ( { dispatch, select } ) => {
if ( select.getServices() === undefined ) {
await resolveSelect( STORE_NAME ).getServices();
@@ -74,15 +73,13 @@ const actions = {

const session = await aiService.startChat( {
history,
model,
modelParams,
} );

dispatch.receiveChat( chatId, {
session,
service,
history,
model,
modelParams,
} );

@@ -151,14 +148,13 @@ const actions = {
* @param {ChatSession} options.session Chat session.
* @param {string} options.service AI service to use.
* @param {Object} options.history Chat history.
* @param {string} options.model Model to use.
* @param {Object} options.modelParams Model parameters.
* @return {Object} Action creator.
*/
receiveChat( chatId, { session, service, history, model, modelParams } ) {
receiveChat( chatId, { session, service, history, modelParams } ) {
return {
type: RECEIVE_CHAT,
payload: { chatId, session, service, history, model, modelParams },
payload: { chatId, session, service, history, modelParams },
};
},

@@ -207,7 +203,7 @@ const actions = {
function reducer( state = initialState, action ) {
switch ( action.type ) {
case RECEIVE_CHAT: {
const { chatId, session, service, history, model, modelParams } =
const { chatId, session, service, history, modelParams } =
action.payload;
chatSessionInstances[ chatId ] = session;
return {
Expand All @@ -216,7 +212,6 @@ function reducer( state = initialState, action ) {
...state.chatConfigs,
[ chatId ]: {
service,
model,
modelParams,
},
},
26 changes: 6 additions & 20 deletions src/ai-store/generative-ai-service.js
@@ -154,11 +154,10 @@ class GenerativeAiService {
*
* @param {Object} args Arguments for generating content.
* @param {string|Object|Object[]} args.content Content data to pass to the model, including the prompt and optional history.
* @param {string} args.model Model slug.
* @param {Object} args.modelParams Model parameters.
* @param {Object} args.modelParams Model parameters (including optional model slug).
* @return {Promise<Object[]>} Model response candidates with the generated text content.
*/
async generateText( { content, model, modelParams } ) {
async generateText( { content, modelParams } ) {
if ( ! this.capabilities.includes( 'text_generation' ) ) {
throw new Error(
__(
@@ -177,7 +176,6 @@ class GenerativeAiService {
method: 'POST',
data: {
content,
model: model || '',
model_params: modelParams || {},
},
} );
@@ -193,11 +191,10 @@ class GenerativeAiService {
*
* @param {Object} args Optional arguments for starting the chat session.
* @param {Object[]} args.history Chat history.
* @param {string} args.model Model slug.
* @param {Object} args.modelParams Model parameters.
* @return {ChatSession} Chat session.
*/
startChat( { history, model, modelParams } ) {
startChat( { history, modelParams } ) {
if ( ! this.capabilities.includes( 'text_generation' ) ) {
throw new Error(
__(
Expand All @@ -207,7 +204,7 @@ class GenerativeAiService {
);
}

return new ChatSession( this, { history, model, modelParams } );
return new ChatSession( this, { history, modelParams } );
}
}

@@ -224,11 +221,10 @@ class BrowserGenerativeAiService extends GenerativeAiService {
*
* @param {Object} args Arguments for generating content.
* @param {string|Object|Object[]} args.content Content data to pass to the model, including the prompt and optional history.
* @param {string} args.model Model slug.
* @param {Object} args.modelParams Model parameters.
* @return {Promise<Object[]>} Model response candidates with the generated text content.
*/
async generateText( { content, model, modelParams } ) {
async generateText( { content, modelParams } ) {
if ( ! this.capabilities.includes( 'text_generation' ) ) {
throw new Error(
__(
@@ -265,13 +261,6 @@ class BrowserGenerativeAiService extends GenerativeAiService {
}
}

if ( model ) {
modelParams = {
model,
...modelParams,
};
}

const session = await window.ai.createTextSession( modelParams );
const resultText = await session.prompt( content );

@@ -301,12 +290,10 @@ export class ChatSession {
* @param {GenerativeAiService} service Generative AI service.
* @param {Object} options Chat options.
* @param {Object[]} options.history Chat history.
* @param {string} options.model Model slug.
* @param {Object} options.modelParams Model parameters.
*/
constructor( service, { history, model, modelParams } ) {
constructor( service, { history, modelParams } ) {
this.service = service;
this.model = model;
this.modelParams = modelParams;

if ( history ) {
Expand Down Expand Up @@ -343,7 +330,6 @@ export class ChatSession {

const candidates = await this.service.generateText( {
content: contents,
model: this.model,
modelParams: this.modelParams,
} );

1 change: 0 additions & 1 deletion src/chatbot/components/ChatbotApp/index.js
@@ -55,7 +55,6 @@ export default function ChatbotApp() {
} else if ( service ) {
startChat( CHAT_ID, {
service: service.slug,
model: 'gemini-1.5-flash', // TODO: Make this configurable.
modelParams: { useWppsChatbot: true },
} );
}