Skip to content

Commit

Permalink
Add filter for hashtags or membership; Add sponsored post to membership feeds
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyfraser committed Nov 25, 2024
1 parent c421d35 commit d1ba3c7
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 185 deletions.
2 changes: 1 addition & 1 deletion rsky-feedgen/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rsky-feedgen"
version = "0.4.1"
version = "1.0.0"
authors = ["Rudy Fraser <[email protected]>"]
description = "A framework for building AT Protocol feed generators, in Rust."
license = "Apache-2.0"
Expand Down
142 changes: 41 additions & 101 deletions rsky-feedgen/src/apis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,29 @@ pub async fn get_posts_by_membership(
params_cursor: Option<&str>,
only_posts: bool,
list: String,
hashtags: Vec<String>,
connection: ReadReplicaConn,
config: &State<FeedGenConfig>,
) -> Result<AlgoResponse, ValidationErrorMessageResponse> {
use crate::schema::membership::dsl as MembershipSchema;
use crate::schema::post::dsl as PostSchema;
use diesel::dsl::any;

let show_sponsored_post = config.show_sponsored_post.clone();
let sponsored_post_uri = config.sponsored_post_uri.clone();
let sponsored_post_probability = config.sponsored_post_probability.clone();

let params_cursor = match params_cursor {
None => None,
Some(params_cursor) => Some(params_cursor.to_string()),
};
let result = connection
.run(move |conn| {
let mut query = PostSchema::post
.inner_join(
.left_join(
MembershipSchema::membership.on(PostSchema::author
.eq(MembershipSchema::did)
.and(MembershipSchema::list.eq(list))
.and(MembershipSchema::list.eq(list.clone()))
.and(MembershipSchema::included.eq(true))),
)
.limit(limit.unwrap_or(30))
Expand All @@ -48,8 +56,7 @@ pub async fn get_posts_by_membership(
query = query.filter(PostSchema::lang.like(format!("%{}%", lang)));
}

if params_cursor.is_some() {
let cursor_str = params_cursor.unwrap();
if let Some(cursor_str) = params_cursor {
let v = cursor_str
.split("::")
.take(2)
Expand Down Expand Up @@ -89,6 +96,23 @@ pub async fn get_posts_by_membership(
.filter(PostSchema::replyParent.is_null())
.filter(PostSchema::replyRoot.is_null());
}

// Adjust the filtering logic
if hashtags.is_empty() {
// No hashtags provided, include only posts where author is in the list
query = query.filter(MembershipSchema::did.is_not_null());
} else {
let hashtag_patterns: Vec<String> = hashtags
.iter()
.map(|hashtag| format!("%{}%", hashtag))
.collect();
query = query.filter(
MembershipSchema::did
.is_not_null()
.or(PostSchema::text.ilike(any(hashtag_patterns))),
);
}

let results = query.load(conn).expect("Error loading post records");

let mut post_results = Vec::new();
Expand All @@ -114,108 +138,24 @@ pub async fn get_posts_by_membership(
})
.for_each(drop);

let new_response = AlgoResponse {
cursor,
feed: post_results,
};
Ok(new_response)
})
.await;

result
}

#[allow(deprecated)]
pub async fn get_blacksky_nsfw(
limit: Option<i64>,
params_cursor: Option<&str>,
connection: ReadReplicaConn,
) -> Result<AlgoResponse, ValidationErrorMessageResponse> {
use crate::schema::image::dsl as ImageSchema;
use crate::schema::post::dsl as PostSchema;
let params_cursor = match params_cursor {
None => None,
Some(params_cursor) => Some(params_cursor.to_string()),
};
let result = connection
.run(move |conn| {
let mut query = PostSchema::post
.limit(limit.unwrap_or(30))
.select(Post::as_select())
.order((PostSchema::createdAt.desc(), PostSchema::cid.desc()))
.into_boxed();
// Insert the sponsored post if the conditions are met
if show_sponsored_post && post_results.len() >= 3 && !sponsored_post_uri.is_empty() {
// Generate a random chance to include the sponsored post based on probability
let mut rng = rand::thread_rng();
let random_chance: f64 = rng.gen();

query = query.filter(
PostSchema::cid.eq_any(
ImageSchema::image
.filter(ImageSchema::labels.contains(vec!["sexy"]))
.filter(ImageSchema::alt.is_not_null())
.select(ImageSchema::postCid),
),
);
// Only include the sponsored post if random chance is below the specified probability
if random_chance < sponsored_post_probability {
// Generate a random index to insert the sponsored post (ensure it's not the last position)
let replace_index = rng.gen_range(0..(post_results.len() - 1));

if params_cursor.is_some() {
let cursor_str = params_cursor.unwrap();
let v = cursor_str
.split("::")
.take(2)
.map(String::from)
.collect::<Vec<_>>();
if let [created_at_c, cid_c] = &v[..] {
if let Ok(timestamp) = created_at_c.parse::<i64>() {
let nanoseconds = 230 * 1000000;
let datetime = DateTime::<Utc>::from_utc(
NaiveDateTime::from_timestamp(timestamp / 1000, nanoseconds),
Utc,
);
let mut timestr = String::new();
match write!(timestr, "{}", datetime.format("%+")) {
Ok(_) => {
query = query.filter(
PostSchema::createdAt.lt(timestr.to_owned()).or(
PostSchema::createdAt
.eq(timestr.to_owned())
.and(PostSchema::cid.lt(cid_c.to_owned())),
),
);
}
Err(error) => eprintln!("Error formatting: {error:?}"),
}
}
} else {
let validation_error = ValidationErrorMessageResponse {
code: Some(ErrorCode::ValidationError),
message: Some("malformed cursor".into()),
// Replace a random post with the sponsored post
post_results[replace_index] = PostResult {
post: sponsored_post_uri.clone(),
};
return Err(validation_error);
}
}

let results = query.load(conn).expect("Error loading post records");

let mut post_results = Vec::new();
let mut cursor: Option<String> = None;

// https://docs.rs/chrono/0.4.26/chrono/format/strftime/index.html
if let Some(last_post) = results.last() {
if let Ok(parsed_time) = NaiveDateTime::parse_from_str(&last_post.created_at, "%+")
{
cursor = Some(format!(
"{}::{}",
parsed_time.timestamp_millis(),
last_post.cid
));
}
}

results
.into_iter()
.map(|result| {
let post_result = PostResult { post: result.uri };
post_results.push(post_result);
})
.for_each(drop);

let new_response = AlgoResponse {
cursor,
feed: post_results,
Expand Down Expand Up @@ -757,8 +697,8 @@ pub async fn queue_creation(
hashtags.contains("#blacktechsky") ||
hashtags.contains("#nbablacksky") ||
hashtags.contains("#addtoblacksky") ||
hashtags.contains("#blackademics") ||
hashtags.contains("#addtoblackskytravel") ||
hashtags.contains("#skytravel") ||
hashtags.contains("#blackskytravel")) &&
!is_blocked &&
!hashtags.contains("#private") &&
Expand Down
86 changes: 3 additions & 83 deletions rsky-feedgen/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ pub(crate) const BLACKSKY_OG: &str =
"at://did:plc:w4xbfzo7kqfes5zb7r6qv3rw/app.bsky.feed.generator/blacksky-op";
pub(crate) const BLACKSKY_TREND: &str =
"at://did:plc:w4xbfzo7kqfes5zb7r6qv3rw/app.bsky.feed.generator/blacksky-trend";
pub(crate) const BLACKSKY_FR: &str =
"at://did:plc:w4xbfzo7kqfes5zb7r6qv3rw/app.bsky.feed.generator/blacksky-fr";
pub(crate) const BLACKSKY_PT: &str =
"at://did:plc:w4xbfzo7kqfes5zb7r6qv3rw/app.bsky.feed.generator/blacksky-pt";
pub(crate) const BLACKSKY_NSFW: &str =
"at://did:plc:w4xbfzo7kqfes5zb7r6qv3rw/app.bsky.feed.generator/blacksky-nsfw";
pub(crate) const BLACKSKY_EDU: &str =
"at://did:plc:w4xbfzo7kqfes5zb7r6qv3rw/app.bsky.feed.generator/blacksky-edu";
pub(crate) const BLACKSKY_TRAVEL: &str =
Expand Down Expand Up @@ -208,7 +202,9 @@ pub async fn index(
cursor,
true,
"blacksky-edu".into(),
vec!["#blackademics".into()],
connection,
config,
)
.await
{
Expand All @@ -233,55 +229,7 @@ pub async fn index(
cursor,
true,
"blacksky-travel".into(),
connection,
)
.await
{
Ok(response) => Ok(Json(response)),
Err(error) => {
eprintln!("Internal Error: {error}");
let internal_error = crate::models::InternalErrorMessageResponse {
code: Some(crate::models::InternalErrorCode::InternalError),
message: Some(error.to_string()),
};
Err(status::Custom(
Status::InternalServerError,
Json(internal_error),
))
}
}
}
_blacksky_fr if _blacksky_fr == BLACKSKY_FR && !is_banned => {
match crate::apis::get_all_posts(
Some("fr".into()),
limit,
cursor,
true,
connection,
config,
)
.await
{
Ok(response) => Ok(Json(response)),
Err(error) => {
eprintln!("Internal Error: {error}");
let internal_error = crate::models::InternalErrorMessageResponse {
code: Some(crate::models::InternalErrorCode::InternalError),
message: Some(error.to_string()),
};
Err(status::Custom(
Status::InternalServerError,
Json(internal_error),
))
}
}
}
_blacksky_pt if _blacksky_pt == BLACKSKY_PT && !is_banned => {
match crate::apis::get_all_posts(
Some("pt".into()),
limit,
cursor,
true,
vec!["blackskytravel".into()],
connection,
config,
)
Expand All @@ -301,22 +249,6 @@ pub async fn index(
}
}
}
_blacksky_nsfw if _blacksky_nsfw == BLACKSKY_NSFW && !is_banned => {
match crate::apis::get_blacksky_nsfw(limit, cursor, connection).await {
Ok(response) => Ok(Json(response)),
Err(error) => {
eprintln!("Internal Error: {error}");
let internal_error = crate::models::InternalErrorMessageResponse {
code: Some(crate::models::InternalErrorCode::InternalError),
message: Some(error.to_string()),
};
Err(status::Custom(
Status::InternalServerError,
Json(internal_error),
))
}
}
}
_blacksky if _blacksky == BLACKSKY && is_banned => {
let banned_response = get_banned_response();
Ok(Json(banned_response))
Expand All @@ -329,18 +261,6 @@ pub async fn index(
let banned_response = get_banned_response();
Ok(Json(banned_response))
}
_blacksky_fr if _blacksky_fr == BLACKSKY_FR && is_banned => {
let banned_response = get_banned_response();
Ok(Json(banned_response))
}
_blacksky_pt if _blacksky_pt == BLACKSKY_PT && is_banned => {
let banned_response = get_banned_response();
Ok(Json(banned_response))
}
_blacksky_nsfw if _blacksky_nsfw == BLACKSKY_NSFW && is_banned => {
let banned_response = get_banned_response();
Ok(Json(banned_response))
}
_blacksky_edu if _blacksky_edu == BLACKSKY_EDU && is_banned => {
let banned_response = get_banned_response();
Ok(Json(banned_response))
Expand Down

0 comments on commit d1ba3c7

Please sign in to comment.