Skip to content

Commit

Permalink
Fix token bucket issue
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjosephhorton committed Sep 20, 2024
1 parent 8592a47 commit d2495e4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 30 deletions.
49 changes: 20 additions & 29 deletions edsl/language_models/LanguageModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ class LanguageModel(
None # This should be something like ["choices", 0, "message", "content"]
)
__rate_limits = None
__default_rate_limits = {
"rpm": 10_000,
"tpm": 2_000_000,
} # TODO: Use the OpenAI Teir 1 rate limits
_safety_factor = 0.8

def __init__(
Expand All @@ -181,6 +177,7 @@ def __init__(
self.remote = False
self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string

# self._rpm / _tpm comes from the class
if rpm is not None:
self._rpm = rpm

Expand Down Expand Up @@ -289,35 +286,40 @@ def set_rate_limits(self, rpm=None, tpm=None) -> None:
>>> m.RPM
100
"""
self._set_rate_limits(rpm=rpm, tpm=tpm)
if rpm is not None:
self._rpm = rpm
if tpm is not None:
self._tpm = tpm
return None
# self._set_rate_limits(rpm=rpm, tpm=tpm)

def _set_rate_limits(self, rpm=None, tpm=None) -> None:
"""Set the rate limits for the model.
# def _set_rate_limits(self, rpm=None, tpm=None) -> None:
# """Set the rate limits for the model.

If the model does not have rate limits, use the default rate limits."""
if rpm is not None and tpm is not None:
self.__rate_limits = {"rpm": rpm, "tpm": tpm}
return
# If the model does not have rate limits, use the default rate limits."""
# if rpm is not None and tpm is not None:
# self.__rate_limits = {"rpm": rpm, "tpm": tpm}
# return

if self.__rate_limits is None:
if hasattr(self, "get_rate_limits"):
self.__rate_limits = self.get_rate_limits()
else:
self.__rate_limits = self.__default_rate_limits
# if self.__rate_limits is None:
# if hasattr(self, "get_rate_limits"):
# self.__rate_limits = self.get_rate_limits()
# else:
# self.__rate_limits = self.__default_rate_limits

@property
def RPM(self):
"""Model's requests-per-minute limit."""
# self._set_rate_limits()
# return self._safety_factor * self.__rate_limits["rpm"]
return self.rpm
return self._rpm

@property
def TPM(self):
"""Model's tokens-per-minute limit."""
# self._set_rate_limits()
# return self._safety_factor * self.__rate_limits["tpm"]
return self.tpm
return self._tpm

@property
def rpm(self):
Expand All @@ -335,17 +337,6 @@ def tpm(self):
def tpm(self, value):
self._tpm = value

@property
def TPM(self):
"""Model's tokens-per-minute limit.
>>> m = LanguageModel.example()
>>> m.TPM > 0
True
"""
self._set_rate_limits()
return self._safety_factor * self.__rate_limits["tpm"]

@staticmethod
def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
"""Return a dictionary of parameters, with passed parameters taking precedence over defaults.
Expand Down
7 changes: 7 additions & 0 deletions tests/base/test_Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class SaveLoadFail(Warning):

class TestBaseModels:
def test_register_subclasses_meta(self):

for key, value in RegisterSubclassesMeta.get_registry().items():
assert key in [
"Result",
Expand All @@ -30,6 +31,12 @@ def test_register_subclasses_meta(self):
"Cache",
"Notebook",
"ModelList",
"FileStore",
"HTMLFileStore",
"CSVFileStore",
"PDFFileStore",
"PNGFileStore",
"SQLiteFileStore",
]

methods = [
Expand Down
13 changes: 13 additions & 0 deletions tests/language_models/test_LanguageModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ class TestLanguageModel(unittest.TestCase):
def setUp(self):
pass

def test_tokens(self):
import random

random_tpm = random.randint(0, 100)
random_tpm = random.randint(0, 100)
m = LanguageModel.example()
m.set_rate_limits(tpm=random_tpm, rpm=random_tpm)
self.assertEqual(m.tpm, random_tpm)
self.assertEqual(m.rpm, random_tpm)

m.rpm = 45
self.assertEqual(m.rpm, 45)

def test_execute_model_call(self):
from edsl.data.Cache import Cache

Expand Down
11 changes: 10 additions & 1 deletion tests/serialization/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ def test_serialization_coverage():

classes_not_covered = (classes_to_cover - data_classes) - set(
# We don't need the base Question or QuestionAddTwoNumbers (a test instance of QuestionFunctional)
["QuestionBase", "QuestionAddTwoNumbers"]
[
"QuestionBase",
"QuestionAddTwoNumbers",
"FileStore",
"HTMLFileStore",
"CSVFileStore",
"PDFFileStore",
"PNGFileStore",
"SQLiteFileStore",
]
)

assert (
Expand Down

0 comments on commit d2495e4

Please sign in to comment.