Skip to content

Commit

Permalink
new OpenAIRetryServiceAdapter in openai-client package
Browse files Browse the repository at this point in the history
  • Loading branch information
phelps-sg committed Jun 10, 2023
1 parent c302744 commit 72f8397
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.cequence.openaiscala.service

import akka.actor.{ActorSystem, Scheduler}
import io.cequence.openaiscala.RetryHelpers
import io.cequence.openaiscala.RetryHelpers.RetrySettings

import scala.concurrent.{ExecutionContext, Future}

private class OpenAIRetryServiceAdapter(
underlying: OpenAIService,
val actorSystem: ActorSystem,
implicit val ec: ExecutionContext,
implicit val retrySettings: RetrySettings,
implicit val scheduler: Scheduler
) extends OpenAIServiceWrapper
with RetryHelpers {

override def close: Unit =
underlying.close

override protected def wrap[T](
fun: OpenAIService => Future[T]
): Future[T] = {
fun(underlying).retryOnFailure
}
}

object OpenAIRetryServiceAdapter {
def apply(underlying: OpenAIService)(implicit
ec: ExecutionContext,
retrySettings: RetrySettings,
scheduler: Scheduler,
actorSystem: ActorSystem
): OpenAIService =
new OpenAIRetryServiceAdapter(
underlying,
actorSystem,
ec,
retrySettings,
scheduler
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.cequence.openaiscala.{OpenAIScalaTokenCountExceededException, StackWal

import scala.concurrent.{ExecutionContext, Future}

@deprecated("Use io.cequence.openaiscala.RetryHelpers")
@deprecated("Use openai-client:io.ceqeunce.openaiscala.service.OpenAIRetryServiceAdapter or openai-client:io.cequence.openaiscala.RetryHelpers")
private class OpenAIRetryServiceAdapter(
underlying: OpenAIService,
maxAttempts: Int,
Expand All @@ -16,6 +16,7 @@ private class OpenAIRetryServiceAdapter(
override protected def wrap[T](
fun: OpenAIService => Future[T]
): Future[T] = {
fun(underlying)
// need to use StackWalker to get the caller function name
fun.toString()
val functionName = StackWalkerUtil.functionName(2).get()
Expand Down

0 comments on commit 72f8397

Please sign in to comment.