diff --git a/docs/help.md b/docs/help.md
index 9db540a3..ba5d3f0d 100644
--- a/docs/help.md
+++ b/docs/help.md
@@ -122,6 +122,7 @@ Options:
   --key TEXT                      API key to use
   --save TEXT                     Save prompt with this template name
   --async                         Run prompt asynchronously
+  -u, --usage                     Show token usage
   --help                          Show this message and exit.
 ```
 
@@ -292,6 +293,7 @@ Options:
   -m, --model TEXT            Filter by model or model alias
   -q, --query TEXT            Search for logs matching this string
   -t, --truncate              Truncate long strings in output
+  -u, --usage                 Include token usage
   -r, --response              Just output the last response
   -c, --current               Show logs from the current conversation
   --cid, --conversation TEXT  Show logs for this conversation ID
diff --git a/docs/logging.md b/docs/logging.md
index 63722e01..56c0379d 100644
--- a/docs/logging.md
+++ b/docs/logging.md
@@ -159,7 +159,10 @@ CREATE TABLE [responses] (
   [response_json] TEXT,
   [conversation_id] TEXT REFERENCES [conversations]([id]),
   [duration_ms] INTEGER,
-  [datetime_utc] TEXT
+  [datetime_utc] TEXT,
+  [input_tokens] INTEGER,
+  [output_tokens] INTEGER,
+  [token_details] TEXT
 );
 CREATE VIRTUAL TABLE [responses_fts] USING FTS5 (
   [prompt],
diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md
index f0efcfd1..9342d355 100644
--- a/docs/plugins/advanced-model-plugins.md
+++ b/docs/plugins/advanced-model-plugins.md
@@ -167,3 +167,33 @@ for prev_response in conversation.responses:
 The `response.text_or_raise()` method used there will return the text from the response or raise a `ValueError` exception if the response is an `AsyncResponse` instance that has not yet been fully resolved.
 
 This is a slightly weird hack to work around the common need to share logic for building up the `messages` list across both sync and async models.
+
+(advanced-model-plugins-usage)=
+
+## Tracking token usage
+
+Models that charge by the token should track the number of tokens used by each prompt. The `response.set_usage()` method can be used to record the number of tokens used by a response - these will then be made available through the Python API and logged to the SQLite database for command-line users.
+
+`response` here is the response object that is passed to `.execute()` as an argument.
+
+Call `response.set_usage()` at the end of your `.execute()` method. It accepts keyword arguments `input=`, `output=` and `details=` - all three are optional. `input` and `output` should be integers, and `details` should be a dictionary that provides additional information beyond the input and output token counts.
+
+This example logs 15 input tokens, 340 output tokens, and notes that 37 tokens were cached:
+
+```python
+response.set_usage(input=15, output=340, details={"cached": 37})
+```
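+
+For illustration, here is a minimal sketch of a non-streaming `.execute()` method that records usage. The `client` object and the shape of its `usage` dictionary are invented for this example - substitute whatever your API actually returns:
+
+```python
+def execute(self, prompt, stream, response, conversation=None):
+    # Hypothetical API call - the client and its response shape are assumptions
+    api_response = client.complete(prompt.prompt)
+    yield api_response["text"]
+    usage = api_response.get("usage") or {}
+    # Record whatever token counts the API reported
+    response.set_usage(
+        input=usage.get("input_tokens"), output=usage.get("output_tokens")
+    )
+```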
diff --git a/llm/cli.py b/llm/cli.py
index c75e0e3e..e0c8e47c 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -33,7 +33,7 @@
 
 from .migrations import migrate
 from .plugins import pm, load_plugins
-from .utils import mimetype_from_path, mimetype_from_string
+from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
 import base64
 import httpx
 import pathlib
@@ -203,6 +203,7 @@ def cli():
 @click.option("--key", help="API key to use")
 @click.option("--save", help="Save prompt with this template name")
 @click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
+@click.option("-u", "--usage", is_flag=True, help="Show token usage")
 def prompt(
     prompt,
     system,
@@ -220,6 +221,7 @@ def prompt(
     key,
     save,
     async_,
+    usage,
 ):
     """
     Execute a prompt
@@ -426,14 +428,24 @@ async def inner():
     except Exception as ex:
         raise click.ClickException(str(ex))
 
+    if isinstance(response, AsyncResponse):
+        response = asyncio.run(response.to_sync_response())
+
+    if usage:
+        # Show token usage to stderr in yellow
+        click.echo(
+            click.style(
+                "Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
+            ),
+            err=True,
+        )
+
     # Log to the database
     if (logs_on() or log) and not no_log:
         log_path = logs_db_path()
         (log_path.parent).mkdir(parents=True, exist_ok=True)
         db = sqlite_utils.Database(log_path)
         migrate(db)
-        if isinstance(response, AsyncResponse):
-            response = asyncio.run(response.to_sync_response())
         response.log_to_db(db)
 
 
@@ -754,6 +766,9 @@ def logs_turn_off():
     responses.conversation_id,
     responses.duration_ms,
     responses.datetime_utc,
+    responses.input_tokens,
+    responses.output_tokens,
+    responses.token_details,
     conversations.name as conversation_name,
     conversations.model as conversation_model"""
 
@@ -809,6 +824,7 @@ def logs_turn_off():
 @click.option("-m", "--model", help="Filter by model or model alias")
 @click.option("-q", "--query", help="Search for logs matching this string")
 @click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
+@click.option("-u", "--usage", is_flag=True, help="Include token usage")
 @click.option("-r", "--response", is_flag=True, help="Just output the last response")
 @click.option(
     "current_conversation",
@@ -836,6 +852,7 @@ def logs_list(
     model,
     query,
     truncate,
+    usage,
     response,
     current_conversation,
     conversation_id,
@@ -998,6 +1015,14 @@ def logs_list(
                         )
 
             click.echo("\n## Response:\n\n{}\n".format(row["response"]))
+            if usage:
+                token_usage = token_usage_string(
+                    row["input_tokens"],
+                    row["output_tokens"],
+                    json.loads(row["token_details"]) if row["token_details"] else None,
+                )
+                if token_usage:
+                    click.echo("## Token usage:\n\n{}\n".format(token_usage))
 
 
 @cli.group(
diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 6234d5b1..ab33d1b4 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -1,6 +1,11 @@
 from llm import AsyncModel, EmbeddingModel, Model, hookimpl
 import llm
-from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
+from llm.utils import (
+    dicts_to_table_string,
+    remove_dict_none_values,
+    logging_client,
+    simplify_usage_dict,
+)
 import click
 import datetime
 import httpx
@@ -391,6 +396,17 @@ def build_messages(self, prompt, conversation):
             messages.append({"role": "user", "content": attachment_message})
         return messages
 
+    def set_usage(self, response, usage):
+        if not usage:
+            return
+        input_tokens = usage.pop("prompt_tokens")
+        output_tokens = usage.pop("completion_tokens")
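+        # total_tokens is redundant (input + output); any remaining keys become details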
+        usage.pop("total_tokens")
+        response.set_usage(
+            input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
+        )
+
     def get_client(self, async_=False):
         kwargs = {}
         if self.api_base:
@@ -445,6 +460,7 @@ def execute(self, prompt, stream, response, conversation=None):
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
+        usage = None
         if stream:
             completion = client.chat.completions.create(
                 model=self.model_name or self.model_id,
@@ -455,6 +471,8 @@ def execute(self, prompt, stream, response, conversation=None):
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
+                if chunk.usage:
+                    usage = chunk.usage.model_dump()
                 try:
                     content = chunk.choices[0].delta.content
                 except IndexError:
@@ -469,8 +487,10 @@ def execute(self, prompt, stream, response, conversation=None):
                 stream=False,
                 **kwargs,
             )
+            usage = completion.usage.model_dump()
             response.response_json = remove_dict_none_values(completion.model_dump())
             yield completion.choices[0].message.content
+        self.set_usage(response, usage)
         response._prompt_json = redact_data({"messages": messages})
 
 
@@ -493,6 +513,7 @@ async def execute(
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client(async_=True)
+        usage = None
         if stream:
             completion = await client.chat.completions.create(
                 model=self.model_name or self.model_id,
@@ -502,6 +523,8 @@ async def execute(
             )
             chunks = []
             async for chunk in completion:
+                if chunk.usage:
+                    usage = chunk.usage.model_dump()
                 chunks.append(chunk)
                 try:
                     content = chunk.choices[0].delta.content
@@ -518,7 +541,9 @@ async def execute(
                 **kwargs,
             )
             response.response_json = remove_dict_none_values(completion.model_dump())
+            usage = completion.usage.model_dump()
             yield completion.choices[0].message.content
+        self.set_usage(response, usage)
         response._prompt_json = redact_data({"messages": messages})
 
 
diff --git a/llm/migrations.py b/llm/migrations.py
index 91da6429..b8ac8b13 100644
--- a/llm/migrations.py
+++ b/llm/migrations.py
@@ -227,3 +227,10 @@ def m012_attachments_tables(db):
         ),
         pk=("response_id", "attachment_id"),
     )
+
+
+@migration
+def m013_usage(db):
+    db["responses"].add_column("input_tokens", int)
+    db["responses"].add_column("output_tokens", int)
+    db["responses"].add_column("token_details", str)
diff --git a/llm/models.py b/llm/models.py
index c160798b..5bf9f11c 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -18,7 +18,7 @@
     Set,
     Union,
 )
-from .utils import mimetype_from_path, mimetype_from_string
+from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
 from abc import ABC, abstractmethod
 import json
 from pydantic import BaseModel
@@ -208,6 +208,20 @@ def __init__(
         self._start: Optional[float] = None
         self._end: Optional[float] = None
         self._start_utcnow: Optional[datetime.datetime] = None
+        self.input_tokens: Optional[int] = None
+        self.output_tokens: Optional[int] = None
+        self.token_details: Optional[dict] = None
+
+    def set_usage(
+        self,
+        *,
+        input: Optional[int] = None,
+        output: Optional[int] = None,
+        details: Optional[dict] = None,
+    ):
+        self.input_tokens = input
+        self.output_tokens = output
+        self.token_details = details
 
     @classmethod
     def from_row(cls, db, row):
@@ -246,6 +260,11 @@ def from_row(cls, db, row):
         ]
         return response
 
+    def token_usage(self) -> str:
+        return token_usage_string(
+            self.input_tokens, self.output_tokens, self.token_details
+        )
+
     def log_to_db(self, db):
         conversation = self.conversation
         if not conversation:
@@ -272,11 +291,16 @@ def log_to_db(self, db):
                 for key, value in dict(self.prompt.options).items()
                 if value is not None
             },
-            "response": self.text(),
+            "response": self.text_or_raise(),
             "response_json": self.json(),
             "conversation_id": conversation.id,
             "duration_ms": self.duration_ms(),
             "datetime_utc": self.datetime_utc(),
+            "input_tokens": self.input_tokens,
+            "output_tokens": self.output_tokens,
+            "token_details": (
+                json.dumps(self.token_details) if self.token_details else None
+            ),
         }
         db["responses"].insert(response)
         # Persist any attachments - loop through with index
@@ -439,6 +463,9 @@ async def to_sync_response(self) -> Response:
         response._end = self._end
         response._start = self._start
         response._start_utcnow = self._start_utcnow
+        response.input_tokens = self.input_tokens
+        response.output_tokens = self.output_tokens
+        response.token_details = self.token_details
         return response
 
     @classmethod
diff --git a/llm/utils.py b/llm/utils.py
index d2618dd4..e9853185 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -127,3 +127,29 @@ def logging_client() -> httpx.Client:
         transport=_LogTransport(httpx.HTTPTransport()),
         event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
     )
+
+
+def simplify_usage_dict(d):
+    # Recursively remove keys whose values are 0, None or an empty dictionary
+    def remove_empty_and_zero(obj):
+        if isinstance(obj, dict):
+            cleaned = {
+                k: remove_empty_and_zero(v)
+                for k, v in obj.items()
+                if v != 0 and v != {}
+            }
+            return {k: v for k, v in cleaned.items() if v is not None and v != {}}
+        return obj
+
+    return remove_empty_and_zero(d) or {}
+
+
+def token_usage_string(input_tokens, output_tokens, token_details) -> str:
+    bits = []
+    if input_tokens is not None:
+        bits.append(f"{format(input_tokens, ',')} input")
+    if output_tokens is not None:
+        bits.append(f"{format(output_tokens, ',')} output")
+    if token_details:
+        bits.append(json.dumps(token_details))
+    return ", ".join(bits)
diff --git a/tests/conftest.py b/tests/conftest.py
index 6fb8bf75..447e1caa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -66,13 +66,18 @@ def enqueue(self, messages):
 
     def execute(self, prompt, stream, response, conversation):
         self.history.append((prompt, stream, response, conversation))
+        gathered = []
         while True:
             try:
                 messages = self._queue.pop(0)
-                yield from messages
+                for message in messages:
+                    gathered.append(message)
+                    yield message
                 break
             except IndexError:
                 break
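+        # Fake usage numbers: words in the prompt as input, chunks yielded as output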
+        response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
 
 
 class AsyncMockModel(llm.AsyncModel):
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 285fa476..1a31f290 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -62,6 +62,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         },
         {
             "id": ANY,
@@ -75,6 +78,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 2,
+            "output_tokens": 1,
+            "token_details": None,
         },
     ]
     # Now continue that conversation
@@ -116,6 +122,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
@@ -153,6 +162,9 @@ def test_chat_system(mock_model, logs_db):
             "conversation_id": ANY,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
@@ -181,6 +193,9 @@ def test_chat_options(mock_model, logs_db):
             "conversation_id": ANY,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
diff --git a/tests/test_cli_openai_models.py b/tests/test_cli_openai_models.py
index b65ad078..3d0a7c16 100644
--- a/tests/test_cli_openai_models.py
+++ b/tests/test_cli_openai_models.py
@@ -147,7 +147,8 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype):
 
 
 @pytest.mark.parametrize("async_", (False, True))
-def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
+@pytest.mark.parametrize("usage", (None, "-u", "--usage"))
+def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_, usage):
     user_path = tmpdir / "user_dir"
     log_db = user_path / "logs.db"
     monkeypatch.setenv("LLM_USER_PATH", str(user_path))
@@ -173,21 +174,25 @@ def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
                 }
             ],
             "usage": {
-                "prompt_tokens": 10,
-                "completion_tokens": 2,
+                "prompt_tokens": 1000,
+                "completion_tokens": 2000,
                 "total_tokens": 12,
             },
             "system_fingerprint": "fp_49254d0e9b",
         },
         headers={"Content-Type": "application/json"},
     )
-    runner = CliRunner()
+    runner = CliRunner(mix_stderr=False)
     args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"]
+    if usage:
+        args.append(usage)
     if async_:
         args.append("--async")
     result = runner.invoke(cli, args, catch_exceptions=False)
     assert result.exit_code == 0
     assert result.output == "Ho ho ho\n"
+    if usage:
+        assert result.stderr == "Token usage: 1,000 input, 2,000 output\n"
     # Confirm it was correctly logged
     assert log_db.exists()
     db = sqlite_utils.Database(str(log_db))
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 0e54cc91..b83ff842 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -37,6 +37,8 @@ def log_path(user_path):
             "model": "davinci",
             "datetime_utc": (start + datetime.timedelta(seconds=i)).isoformat(),
             "conversation_id": "abc123",
+            "input_tokens": 2,
+            "output_tokens": 5,
         }
         for i in range(100)
     )
@@ -46,9 +48,12 @@ def log_path(user_path):
 datetime_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
 
 
-def test_logs_text(log_path):
+@pytest.mark.parametrize("usage", (False, True))
+def test_logs_text(log_path, usage):
     runner = CliRunner()
     args = ["logs", "-p", str(log_path)]
+    if usage:
+        args.append("-u")
     result = runner.invoke(cli, args, catch_exceptions=False)
     assert result.exit_code == 0
     output = result.output
@@ -64,18 +69,24 @@ def test_logs_text(log_path):
         "system\n\n"
         "## Response:\n\n"
         "response\n\n"
+    ) + ("## Token usage:\n\n2 input, 5 output\n\n" if usage else "") + (
         "# YYYY-MM-DDTHH:MM:SS    conversation: abc123\n\n"
         "Model: **davinci**\n\n"
         "## Prompt:\n\n"
         "prompt\n\n"
         "## Response:\n\n"
         "response\n\n"
+    ) + (
+        "## Token usage:\n\n2 input, 5 output\n\n" if usage else ""
+    ) + (
         "# YYYY-MM-DDTHH:MM:SS    conversation: abc123\n\n"
         "Model: **davinci**\n\n"
         "## Prompt:\n\n"
         "prompt\n\n"
         "## Response:\n\n"
         "response\n\n"
+    ) + (
+        "## Token usage:\n\n2 input, 5 output\n\n" if usage else ""
     )
 
 
diff --git a/tests/test_migrate.py b/tests/test_migrate.py
index 1c68de93..d1da5571 100644
--- a/tests/test_migrate.py
+++ b/tests/test_migrate.py
@@ -17,6 +17,9 @@
     "conversation_id": str,
     "duration_ms": int,
     "datetime_utc": str,
+    "input_tokens": int,
+    "output_tokens": int,
+    "token_details": str,
 }
 
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..85ed54ae
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,42 @@
+import pytest
+from llm.utils import simplify_usage_dict
+
+
+@pytest.mark.parametrize(
+    "input_data,expected_output",
+    [
+        (
+            {
+                "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
+                "completion_tokens_details": {
+                    "reasoning_tokens": 0,
+                    "audio_tokens": 1,
+                    "accepted_prediction_tokens": 0,
+                    "rejected_prediction_tokens": 0,
+                },
+            },
+            {"completion_tokens_details": {"audio_tokens": 1}},
+        ),
+        (
+            {
+                "details": {"tokens": 5, "audio_tokens": 2},
+                "more_details": {"accepted_tokens": 3},
+            },
+            {
+                "details": {"tokens": 5, "audio_tokens": 2},
+                "more_details": {"accepted_tokens": 3},
+            },
+        ),
+        ({"details": {"tokens": 0, "audio_tokens": 0}, "more_details": {}}, {}),
+        ({"level1": {"level2": {"value": 0, "another_value": {}}}}, {}),
+        (
+            {
+                "level1": {"level2": {"value": 0, "another_value": 1}},
+                "level3": {"empty_dict": {}, "valid_token": 10},
+            },
+            {"level1": {"level2": {"another_value": 1}}, "level3": {"valid_token": 10}},
+        ),
+    ],
+)
+def test_simplify_usage_dict(input_data, expected_output):
+    assert simplify_usage_dict(input_data) == expected_output