diff --git a/.rubocop.yml b/.rubocop.yml index 947d23cc..9a80d12d 100644 --- a/.rubocop.yml +++ b/.rubocop.yml @@ -59,3 +59,9 @@ Style/NegatedIfElseCondition: Exclude: - 'lib/ronin/exploits/sqli.rb' - 'lib/ronin/exploits/mixins/html.rb' + +# we need to define a #=== operator for matching payload class(es) +Style/CaseEquality: + Exclude: + - 'lib/ronin/exploits/mixins/has_payload.rb' + - 'spec/mixins/has_payload_spec.rb' diff --git a/lib/ronin/exploits/mixins/has_payload.rb b/lib/ronin/exploits/mixins/has_payload.rb index 93280661..6318286f 100644 --- a/lib/ronin/exploits/mixins/has_payload.rb +++ b/lib/ronin/exploits/mixins/has_payload.rb @@ -58,6 +58,57 @@ def self.included(exploit) exploit.extend ClassMethods end + # + # Helper class for matching a payload against multiple accepted payload + # classes. + # + # @api private + # + # @since 1.3.0 + # + class PayloadClasses + + # The payload classes to match against. + # + # @return [Array>] + attr_reader :classes + + # + # Initializes the payload classes. + # + # @param [Array>] classes + # The payload classes. + # + def initialize(classes) + @classes = classes + end + + # + # Compares the payload object to the payload classes. + # + # @param [Ronin::Payloads::Payload] payload + # The payload object to match. + # + # @return [Boolean] + # Indicates whether the payload object inherits from any of the + # payload classes. + # + def ===(payload) + @classes.any? { |payload_class| payload_class === payload } + end + + # + # Converts the list payload classes to a String. + # + # @return [String] + # The comma separated list of payload class names. + # + def to_s + @classes.join(', ') + end + + end + # # Class methods. # @@ -72,9 +123,13 @@ module ClassMethods # @return [Class] # The exploit's compatible payload base class. # - def payload_class(new_payload_class=nil) - if new_payload_class - @payload_class = new_payload_class + def payload_class(*new_payload_classes) + unless new_payload_classes.empty? + @payload_class = if new_payload_classes.length == 1 + new_payload_classes.first + else + PayloadClasses.new(new_payload_classes) + end else @payload_class ||= if superclass.kind_of?(ClassMethods) superclass.payload_class @@ -113,7 +168,7 @@ def initialize(payload: nil, **kwargs) # def payload=(new_payload) if new_payload.kind_of?(Payloads::Payload) - unless new_payload.kind_of?(self.class.payload_class) + unless self.class.payload_class === new_payload raise(IncompatiblePayload,"incompatible payload, must be a #{self.class.payload_class} payload: #{new_payload.inspect}") end end diff --git a/spec/mixins/has_payload_spec.rb b/spec/mixins/has_payload_spec.rb index cb052160..bdb1e507 100644 --- a/spec/mixins/has_payload_spec.rb +++ b/spec/mixins/has_payload_spec.rb @@ -8,6 +8,9 @@ module TestHasPayload class TestPayload < Ronin::Payloads::Payload end + class TestPayload2 < Ronin::Payloads::Payload + end + class InheritedPayload < TestPayload end @@ -24,6 +27,12 @@ class WithPayloadClass < Ronin::Exploits::Exploit payload_class TestPayload end + class WithPayloadClasses < Ronin::Exploits::Exploit + include Ronin::Exploits::Mixins::HasPayload + + payload_class TestPayload, TestPayload2 + end + class InheritesPayloadClass < WithPayloadClass end @@ -32,6 +41,46 @@ class InheritesAndOverridesPayloadClass < WithPayloadClass end end + describe described_class::PayloadClasses do + let(:classes) do + [TestHasPayload::TestPayload, TestHasPayload::TestPayload2] + end + + subject { described_class.new(classes) } + + describe "#initialize" do + it "must set #classes" do + expect(subject.classes).to eq(classes) + end + end + + describe "#===" do + context "when the given payload object is kind of one of the payload classes" do + let(:payload) { TestHasPayload::TestPayload2.new } + + it "must return true" do + expect(subject === payload).to be(true) + end + end + + context "when the given payload object is not kind of any of the payload classes" do + let(:payload) { TestHasPayload::TestOtherPayload.new } + + it "must return false" do + expect(subject === payload).to be(false) + end + end + end + + describe "#to_s" do + it "must return a comma-separated String of the class names" do + expect(subject.to_s).to eq( + "#{TestHasPayload::TestPayload}, #{TestHasPayload::TestPayload2}" + ) + end + end + end + describe ".payload_class" do subject { test_class } @@ -44,10 +93,23 @@ class InheritesAndOverridesPayloadClass < WithPayloadClass end context "when the payload_class has been set in the Exploit class" do - let(:test_class) { TestHasPayload::WithPayloadClass } + context "with a single payload class" do + let(:test_class) { TestHasPayload::WithPayloadClass } - it "must set the payload_class to the given payload class" do - expect(subject.payload_class).to be(TestHasPayload::TestPayload) + it "must set the payload_class to the given payload class" do + expect(subject.payload_class).to be(TestHasPayload::TestPayload) + end + end + + context "with multiple payload classes" do + let(:test_class) { TestHasPayload::WithPayloadClasses } + + it "must set payload_class to a #{described_class}::PayloadClasses objects with the given payload classes" do + expect(subject.payload_class).to be_kind_of(described_class::PayloadClasses) + expect(subject.payload_class.classes).to eq( + [TestHasPayload::TestPayload, TestHasPayload::TestPayload2] + ) + end end end @@ -156,35 +218,71 @@ class InheritesAndOverridesPayloadClass < WithPayloadClass end context "but the Exploit has defined a payload_class" do - let(:test_class) { TestHasPayload::WithPayloadClass } + context "with a single payload class" do + let(:test_class) { TestHasPayload::WithPayloadClass } - context "and the given payload object is a kind of payload_class" do - let(:payload) { test_class.payload_class.new } + context "and the given payload object is a kind of payload_class" do + let(:payload) { test_class.payload_class.new } - before { subject.payload = payload } + before { subject.payload = payload } - it "must set #payload" do - expect(subject.payload).to be(payload) + it "must set #payload" do + expect(subject.payload).to be(payload) + end end - end - context "and the given payload object inherits from payload_class" do - let(:payload) { TestHasPayload::InheritedPayload.new } + context "and the given payload object inherits from payload_class" do + let(:payload) { TestHasPayload::InheritedPayload.new } - before { subject.payload = payload } + before { subject.payload = payload } - it "must set #payload" do - expect(subject.payload).to be(payload) + it "must set #payload" do + expect(subject.payload).to be(payload) + end + end + + context "but the given payload is not a kind of payload_class" do + let(:payload) { TestHasPayload::TestOtherPayload.new } + + it do + expect { + subject.payload = payload + }.to raise_error(Ronin::Exploits::IncompatiblePayload,"incompatible payload, must be a #{test_class.payload_class} payload: #{payload.inspect}") + end end end - context "but the given payload is not a kind of payload_class" do - let(:payload) { TestHasPayload::TestOtherPayload.new } + context "with multiple payload classes" do + let(:test_class) { TestHasPayload::WithPayloadClasses } + + context "and the given payload object is a kind of payload_class" do + let(:payload) { test_class.payload_class.classes.last.new } + + before { subject.payload = payload } + + it "must set #payload" do + expect(subject.payload).to be(payload) + end + end + + context "and the given payload object inherits from payload_class" do + let(:payload) { TestHasPayload::InheritedPayload.new } + + before { subject.payload = payload } + + it "must set #payload" do + expect(subject.payload).to be(payload) + end + end + + context "but the given payload is not a kind of payload_class" do + let(:payload) { TestHasPayload::TestOtherPayload.new } - it do - expect { - subject.payload = payload - }.to raise_error(Ronin::Exploits::IncompatiblePayload,"incompatible payload, must be a #{test_class.payload_class} payload: #{payload.inspect}") + it do + expect { + subject.payload = payload + }.to raise_error(Ronin::Exploits::IncompatiblePayload,"incompatible payload, must be a #{test_class.payload_class} payload: #{payload.inspect}") + end end end end