-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #116 from brainlid/me-add-openai-image-endpoint
add openai image endpoint support (aka DALL-E-2 & DALL-E-3)
- Loading branch information
Showing
6 changed files
with
813 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
defmodule LangChain.Images do | ||
@moduledoc """ | ||
Functions for working with `LangChain.GeneratedImage` files. | ||
""" | ||
require Logger | ||
alias LangChain.LangChainError | ||
alias LangChain.Images.GeneratedImage | ||
|
||
@doc """ | ||
Save a list of `%GeneratedImage{}` images. | ||
Pipe friendly for handling the result a `LangChain.Images.OpenAIImage.call/1` | ||
where it handles a success or passes the error result through. | ||
""" | ||
@spec save_images([GeneratedImage.t()], path :: String.t(), filename_prefix :: String.t()) :: | ||
{:ok, [String.t()]} | {:error, String.t()} | ||
def save_images({:ok, images}, path, filename_prefix) do | ||
# unwrap the tuple | ||
save_images(images, path, filename_prefix) | ||
end | ||
|
||
def save_images(images, path, filename_prefix) when is_list(images) do | ||
try do | ||
Enum.with_index(images, fn %GeneratedImage{} = image, idx -> | ||
index_as_string = to_string(idx + 1) | ||
number = String.pad_leading(index_as_string, 2, "0") | ||
filename = filename_prefix <> number <> ".#{to_string(image.image_type)}" | ||
|
||
# if it saved successfully, return the generated filename | ||
case save_to_file(image, Path.join([path, filename])) do | ||
:ok -> | ||
filename | ||
|
||
{:error, reason} -> | ||
raise LangChainError, "File save error. Reason: #{inspect(reason)}" | ||
end | ||
end) | ||
else | ||
result when is_list(result) -> | ||
{:ok, result} | ||
rescue | ||
err in [LangChainError] -> | ||
{:error, err.message} | ||
end | ||
end | ||
|
||
def save_images({:error, _reason} = error, _path, _filename_prefix), do: error | ||
|
||
@doc """ | ||
Save the generated image file to a local directory. If the GeneratedFile is an | ||
URL, it is first downloaded then saved. If the is a Base64 encoded image, it | ||
is decoded and saved. | ||
""" | ||
@spec save_to_file(GeneratedImage.t(), String.t()) :: :ok | {:error, String.t()} | ||
def save_to_file(%GeneratedImage{type: :url} = image, target_path) do | ||
# When a generated image is type `:url`, the content is the URL | ||
case Req.get(image.content) do | ||
{:ok, %Req.Response{body: body, status: 200}} -> | ||
# Save the file locally | ||
do_write_to_file(body, target_path) | ||
|
||
{:ok, %Req.Response{status: 404}} -> | ||
{:error, "Image file not found"} | ||
|
||
{:ok, %Req.Response{status: 500}} -> | ||
{:error, "Failed with server error 500"} | ||
|
||
{:error, reason} -> | ||
# Handle error | ||
Logger.error("Failed to download image: #{inspect(reason)}") | ||
{:error, reason} | ||
end | ||
end | ||
|
||
def save_to_file(%GeneratedImage{type: :base64} = image, target_path) do | ||
case Base.decode64(image.content) do | ||
{:ok, binary_data} -> | ||
do_write_to_file(binary_data, target_path) | ||
|
||
:error -> | ||
{:error, "Failed to base64 decode image data"} | ||
end | ||
end | ||
|
||
# write the contents to the file | ||
@spec do_write_to_file(binary(), String.t()) :: :ok | {:error, String.t()} | ||
defp do_write_to_file(data, target_path) do | ||
case File.write(target_path, data) do | ||
:ok -> | ||
:ok | ||
|
||
{:error, :eacces} -> | ||
{:error, "Missing write permissions for the parent directory"} | ||
|
||
{:error, :eexist} -> | ||
{:error, "A file or directory already exists"} | ||
|
||
{:error, :enoent} -> | ||
{:error, "File path is invalid"} | ||
|
||
{:error, :enospc} -> | ||
{:error, "No space left on device"} | ||
|
||
{:error, :enotdir} -> | ||
{:error, "Part of path is not a directory"} | ||
|
||
{:error, reason} -> | ||
Logger.error( | ||
"Failed to save base64 image to file #{inspect(target_path)}. Reason: #{inspect(reason)}" | ||
) | ||
|
||
{:error, "Unrecognized error reason encountered: #{inspect(reason)}"} | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
defmodule LangChain.Images.GeneratedImage do | ||
@moduledoc """ | ||
Represents a generated image where we have either the base64 encoded contents | ||
or a temporary URL to it. | ||
## Example | ||
Created when an image generation request completes and we have an image. | ||
GeneratedImage.new!(%{ | ||
image_type: :png, | ||
type: :url, | ||
content: "https://example.com/my_image.png", | ||
prompt: "The prompt used for image generation" | ||
}) | ||
""" | ||
use Ecto.Schema | ||
import Ecto.Changeset | ||
require Logger | ||
alias __MODULE__ | ||
alias LangChain.LangChainError | ||
|
||
@primary_key false | ||
embedded_schema do | ||
field :image_type, Ecto.Enum, values: [:png, :jpg], default: :png | ||
field :type, Ecto.Enum, values: [:base64, :url], default: :url | ||
# When a :url, content is the URL. When base64, content is the encoded data. | ||
field :content, :string | ||
|
||
# The prompt used when generating the image. It may have been altered by the | ||
# LLM from the original request. | ||
field :prompt, :string | ||
field :metadata, :map | ||
field :created_at, :utc_datetime | ||
end | ||
|
||
@type t :: %GeneratedImage{} | ||
|
||
@update_fields [:image_type, :type, :content, :prompt, :metadata, :created_at] | ||
@create_fields @update_fields | ||
|
||
@doc """ | ||
Build a new GeneratedImage and return an `:ok`/`:error` tuple with the result. | ||
""" | ||
@spec new(attrs :: map()) :: {:ok, t()} | {:error, Ecto.Changeset.t()} | ||
def new(attrs \\ %{}) do | ||
%GeneratedImage{} | ||
|> cast(attrs, @create_fields) | ||
|> common_validations() | ||
|> apply_action(:insert) | ||
end | ||
|
||
@doc """ | ||
Build a new GeneratedImage and return it or raise an error if invalid. | ||
""" | ||
@spec new!(attrs :: map()) :: t() | no_return() | ||
def new!(attrs \\ %{}) do | ||
case new(attrs) do | ||
{:ok, message} -> | ||
message | ||
|
||
{:error, changeset} -> | ||
raise LangChainError, changeset | ||
end | ||
end | ||
|
||
defp common_validations(changeset) do | ||
changeset | ||
|> validate_required([:image_type, :type, :content]) | ||
end | ||
end |
Oops, something went wrong.