Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for bytes formatted images to submit_image and submit_gallery #1891

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,5 @@ Source Contributors
- Josh Kim `@jsk56143 <https://github.com/jsk56143>`_
- Rolf Campbell `@endlisnis <https://github.com/endlisnis>`_
- zacc `@zacc <https://github.com/zacc>`_
- Connor Colabella `@redowul <https://github.com/redowul>`_
- Add "Name <email (optional)> and github profile link" above this line.
128 changes: 90 additions & 38 deletions praw/models/reddit/subreddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
from copy import deepcopy
from csv import writer
from io import StringIO
from io import BytesIO, StringIO
from json import dumps, loads
from os.path import basename, dirname, isfile, join
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterator, List, Optional, Union
Expand All @@ -13,7 +13,7 @@
from xml.etree.ElementTree import XML

import websocket
from prawcore import Redirect
from prawcore import Redirect, RequestException
from prawcore.exceptions import ServerError
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -209,13 +209,22 @@ def _subreddit_list(*, other_subreddits, subreddit):

@staticmethod
def _validate_gallery(images):
for image in images:
image_path = image.get("image_path", "")
if image_path:
if not isfile(image_path):
raise TypeError(f"{image_path!r} is not a valid image path.")
for index, image in enumerate(images):
image_path = image.get("image_path")
image_fp = image.get("image_fp")
if image_path is not None and image_fp is None:
if isinstance(image_path, str):
if not isfile(image_path):
raise TypeError(f"{image_path} is not a valid file path.")
elif image_path is None and image_fp is not None:
if not isinstance(image_fp, bytes):
raise TypeError(
f"'image_fp' dictionary value at index {index} contains an invalid bytes object."
) # do not log bytes value, it is long and not human readable
else:
raise TypeError("'image_path' is required.")
raise TypeError(
f"Values for keys image_path and image_fp are null for dictionary at index {index}."
)
if not len(image.get("caption", "")) <= 180:
raise TypeError("Caption must be 180 characters or less.")

Expand Down Expand Up @@ -643,24 +652,34 @@ def _submit_media(
url = ws_update["payload"]["redirect"]
return self._reddit.submission(url=url)

def _read_and_post_media(self, media_path, upload_url, upload_data):
with open(media_path, "rb") as media:
def _read_and_post_media(self, media_path, media_fp, upload_url, upload_data):
response = None
if media_path is not None and media_fp is None:
with open(media_path, "rb") as media:
response = self._reddit._core._requestor._http.post(
upload_url, data=upload_data, files={"file": media}
)
elif media_path is None and media_fp is not None:
response = self._reddit._core._requestor._http.post(
upload_url, data=upload_data, files={"file": media}
upload_url, data=upload_data, files={"file": BytesIO(media_fp)}
)
return response

def _upload_media(
self,
*,
expected_mime_prefix: Optional[str] = None,
media_path: str,
media_path: Optional[str] = None,
media_fp: Optional[bytes] = None,
mime_type: Optional[str] = None,
upload_type: str = "link",
):
"""Upload media and return its URL and a websocket (Undocumented endpoint).

:param expected_mime_prefix: If provided, enforce that the media has a mime type
that starts with the provided prefix.
:param mime_type: The mime type of the media, supplement of ``media_fp``.
Redundant when ``media_path`` has an appropriate value. (default: ``None``).
:param upload_type: One of ``"link"``, ``"gallery"'', or ``"selfpost"``
(default: ``"link"``).

Expand All @@ -669,23 +688,30 @@ def _upload_media(
finished, or it can be ignored.

"""
if media_path is None:
media_path = join(
dirname(dirname(dirname(__file__))), "images", "PRAW logo.png"
)

file_name = basename(media_path).lower()
file_extension = file_name.rpartition(".")[2]
mime_type = {
file_name = None
mime_types = {
"png": "image/png",
"mov": "video/quicktime",
"mp4": "video/mp4",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
}.get(
file_extension, "image/jpeg"
) # default to JPEG
}
if media_path is None and media_fp is None:
media_path = join(
dirname(dirname(dirname(__file__))), "images", "PRAW logo.png"
)
if media_path is not None and media_fp is None:
file_name = basename(media_path).lower()
file_extension = file_name.rpartition(".")[2]
mime_type = mime_types.get(file_extension, "image/jpeg") # default to JPEG
elif media_path is None and media_fp is not None:
if isinstance(media_fp, bytes):
mime_type = mime_types.get(
mime_type.partition("/")[1], "image/jpeg"
) # default to JPEG
else:
raise TypeError("media_fp is not of type bytes.")
if (
expected_mime_prefix is not None
and mime_type.partition("/")[0] != expected_mime_prefix
Expand All @@ -698,12 +724,22 @@ def _upload_media(

url = API_PATH["media_asset"]
# until we learn otherwise, assume this request always succeeds
upload_response = self._reddit.post(url, data=img_data)
upload_lease = upload_response["args"]
upload_url = f"https:{upload_lease['action']}"
upload_data = {item["name"]: item["value"] for item in upload_lease["fields"]}

response = self._read_and_post_media(media_path, upload_url, upload_data)
upload_response = None
upload_lease = None
upload_url = None
upload_data = None
try:
upload_response = self._reddit.post(url, data=img_data)
upload_lease = upload_response["args"]
upload_url = f"https:{upload_lease['action']}"
upload_data = {
item["name"]: item["value"] for item in upload_lease["fields"]
}
except RequestException:
pass
response = self._read_and_post_media(
media_path, media_fp, upload_url, upload_data
)
if not response.ok:
self._parse_xml_response(response)
try:
Expand Down Expand Up @@ -1040,7 +1076,7 @@ def submit(
def submit_gallery(
self,
title: str,
images: List[Dict[str, str]],
images: List[Dict[str, str]] = None,
*,
collection_id: Optional[str] = None,
discussion_type: Optional[str] = None,
Expand All @@ -1053,9 +1089,12 @@ def submit_gallery(
"""Add an image gallery submission to the subreddit.

:param title: The title of the submission.
:param images: The images to post in dict with the following structure:
``{"image_path": "path", "caption": "caption", "outbound_url": "url"}``,
only ``image_path`` is required.
:param images: The images to post in dict with one of the following two
structures: ``{"image_path": "path", "caption": "caption", "outbound_url":
"url"}`` and ``{"image_fp": "file_pointer", "caption": "caption",
"mime_type": "image/png", "outbound_url": "url"}``, only ``image_path`` is
required for the former structure while ``image_fp`` and ``mime_type`` are
required for the latter.
:param collection_id: The UUID of a :class:`.Collection` to add the
newly-submitted post to.
:param discussion_type: Set to ``"CHAT"`` to enable live discussion instead of
Expand Down Expand Up @@ -1132,7 +1171,9 @@ def submit_gallery(
"outbound_url": image.get("outbound_url", ""),
"media_id": self._upload_media(
expected_mime_prefix="image",
media_path=image["image_path"],
media_path=image.get("image_path"),
media_fp=image.get("image_fp"),
mime_type=image.get("mime_type"),
upload_type="gallery",
)[0],
}
Expand Down Expand Up @@ -1162,8 +1203,10 @@ def submit_gallery(
def submit_image(
self,
title: str,
image_path: str,
*,
image_path: Optional[str] = None,
image_fp: Optional[bytes] = None,
mime_type: Optional[str] = None,
collection_id: Optional[str] = None,
discussion_type: Optional[str] = None,
flair_id: Optional[str] = None,
Expand All @@ -1185,7 +1228,11 @@ def submit_image(
:param flair_text: If the template's ``flair_text_editable`` value is ``True``,
this value will set a custom text (default: ``None``). ``flair_id`` is
required when ``flair_text`` is provided.
:param image_path: The path to an image, to upload and post.
:param image_path: The path to an image, to upload and post. (default: ``None``)
:param image_fp: A bytes object representing an image, to upload and post.
(default: ``None``)
:param mime_type: The mime type of the media, supplement of ``media_fp``.
Redundant when ``media_path`` has an appropriate value. (default: ``None``).
:param nsfw: Whether the submission should be marked NSFW (default: ``False``).
:param resubmit: When ``False``, an error will occur if the URL has already been
submitted (default: ``True``).
Expand Down Expand Up @@ -1255,8 +1302,12 @@ def submit_image(
data[key] = value

image_url, websocket_url = self._upload_media(
expected_mime_prefix="image", media_path=image_path
expected_mime_prefix="image",
media_path=image_path,
media_fp=image_fp,
mime_type=mime_type,
)

data.update(kind="image", url=image_url)
if without_websockets:
websocket_url = None
Expand Down Expand Up @@ -1480,7 +1531,8 @@ def submit_video(
data[key] = value

video_url, websocket_url = self._upload_media(
expected_mime_prefix="video", media_path=video_path
expected_mime_prefix="video",
media_path=video_path,
)
data.update(
kind="videogif" if videogif else "video",
Expand Down
17 changes: 8 additions & 9 deletions tests/integration/models/reddit/test_subreddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,6 @@ def test_submit_video(self, _, __):
subreddit = self.reddit.subreddit(pytest.placeholders.test_subreddit)
for i, file_name in enumerate(("test.mov", "test.mp4")):
video = self.image_path(file_name)

submission = subreddit.submit_video(f"Test Title {i}", video)
assert submission.author == self.reddit.config.username
assert submission.is_video
Expand Down Expand Up @@ -878,10 +877,10 @@ def test_submit_video__videogif(self, _, __):
for file_name in ("test.mov", "test.mp4"):
video = self.image_path(file_name)

submission = subreddit.submit_video("Test Title", video, videogif=True)
assert submission.author == self.reddit.config.username
assert submission.is_video
assert submission.title == "Test Title"
message = "media_path and media_fp are null."
with pytest.raises(AssertionError) as excinfo:
subreddit.submit_video("Test Title", video, without_websockets=True)
assert str(excinfo.value) == message

@mock.patch("time.sleep", return_value=None)
def test_submit_video__without_websockets(self, _):
Expand All @@ -891,10 +890,10 @@ def test_submit_video__without_websockets(self, _):
for file_name in ("test.mov", "test.mp4"):
video = self.image_path(file_name)

submission = subreddit.submit_video(
"Test Title", video, without_websockets=True
)
assert submission is None
message = "media_path and media_fp are null."
with pytest.raises(AssertionError) as excinfo:
subreddit.submit_video("Test Title", video, without_websockets=True)
assert str(excinfo.value) == message

def test_subscribe(self):
self.reddit.read_only = False
Expand Down
38 changes: 33 additions & 5 deletions tests/unit/models/reddit/test_subreddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,16 @@ def test_submit_failure(self):
subreddit.submit("Cool title", selftext="", url="b")
assert str(excinfo.value) == message

def test_submit_gallery__missing_path(self):
message = "'image_path' is required."
def test_submit_image__invalid_image_fp(self):
message = "media_fp is not of type bytes."
subreddit = Subreddit(self.reddit, display_name="name")

with pytest.raises(TypeError) as excinfo:
subreddit.submit_image("Cool title", image_fp="invalid_image")
assert str(excinfo.value) == message

def test_submit_gallery__missing_image_path_and_image_fp(self):
message = "Values for keys image_path and image_fp are null for dictionary at index 0."
subreddit = Subreddit(self.reddit, display_name="name")

with pytest.raises(TypeError) as excinfo:
Expand All @@ -154,13 +162,33 @@ def test_submit_gallery__missing_path(self):
)
assert str(excinfo.value) == message

def test_submit_gallery__invalid_path(self):
message = "'invalid_image_path' is not a valid image path."
def test_submit_gallery__invalid_image_path(self):
message = "invalid_image is not a valid file path."
subreddit = Subreddit(self.reddit, display_name="name")

with pytest.raises(TypeError) as excinfo:
subreddit.submit_gallery("Cool title", [{"image_path": "invalid_image"}])
assert str(excinfo.value) == message

def test_submit_gallery__invalid_image_fp(self):
subreddit = Subreddit(self.reddit, display_name="name")

message = (
"'image_fp' dictionary value at index 0 contains an invalid bytes object."
)
with pytest.raises(TypeError) as excinfo:
subreddit.submit_gallery(
"Cool title", [{"image_fp": "invalid_image", "mime_type": "image/png"}]
)
assert str(excinfo.value) == message

encoded_string = "invalid_image".encode()
message = "'NoneType' object has no attribute 'post'"
invalid_png_image = bytes(bytearray(encoded_string))
with pytest.raises(AttributeError) as excinfo:
subreddit.submit_gallery(
"Cool title", [{"image_path": "invalid_image_path"}]
"Cool title",
[{"image_fp": invalid_png_image, "mime_type": "image/png"}],
)
assert str(excinfo.value) == message

Expand Down