diff --git a/prawtools/stats.py b/prawtools/stats.py index e9f0038..530c7ae 100644 --- a/prawtools/stats.py +++ b/prawtools/stats.py @@ -105,7 +105,7 @@ def _save_report(title, body): def _user(user): return "_deleted_" if user is None else tt("/u/{}").format(user) - def __init__(self, subreddit, site, distinguished, reddit=None): + def __init__(self, subreddit, site, distinguished, output_subreddit, reddit=None): """Initialize the SubredditStats instance with config options.""" self.commenters = defaultdict(list) self.comments = [] @@ -115,7 +115,7 @@ def __init__(self, subreddit, site, distinguished, reddit=None): self.reddit = reddit or Reddit(site, check_for_updates=False, user_agent=AGENT) self.submissions = {} self.submitters = defaultdict(list) - self.submit_subreddit = self.reddit.subreddit("subreddit_stats") + self.submit_subreddit = self.reddit.subreddit(output_subreddit) self.subreddit = self.reddit.subreddit(subreddit) def basic_stats(self): @@ -443,6 +443,13 @@ def main(): default=10, help="Number of top submitters to display " "[default %default]", ) + parser.add_option( + "-o", + "--output", + type="string", + default="subreddit_stats", + help="Subreddit to publish results to " "[default %default]", + ) options, args = parser.parse_args() @@ -458,7 +465,7 @@ def main(): parser.error("SUBREDDIT and VIEW must be provided") subreddit, view = args check_for_updates(options) - srs = SubredditStats(subreddit, options.site, options.distinguished) + srs = SubredditStats(subreddit, options.site, options.distinguished, options.output) result = srs.run(view, options.submitters, options.commenters) if result: print(result.permalink) diff --git a/tests/test_stats.py b/tests/test_stats.py index f54f84a..516c28b 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -8,7 +8,7 @@ class StatsTest(IntegrationTest): def setUp(self): """Setup runs before all test cases.""" - self.srs = SubredditStats("redditdev", None, None) + self.srs = SubredditStats("redditdev", None, None, None) super(StatsTest, self).setUp(self.srs.reddit._core._requestor._http) def test_recent(self):