diff --git a/src/types.jl b/src/types.jl index 6dc2320e..99723863 100644 --- a/src/types.jl +++ b/src/types.jl @@ -608,7 +608,15 @@ function basicsymbolic(f, args, stype, metadata) end res elseif f == (^) && length(args) == 2 - res = args[1] ^ args[2] + if args[2] isa Rational && !isa(args[1], Symbolic) + if isinteger(args[2]) + res = args[1] ^ numerator(args[2]) + else + @goto FALLBACK + end + else + res = args[1] ^ args[2] + end if ispow(res) @set! res.metadata = metadata end @@ -1210,9 +1218,20 @@ function ^(a::SN, b) elseif b isa Number && b < 0 Div(1, a ^ (-b)) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.coeff, b) + new_dict = mapvalues((k, v) -> b*v, a.dict) + if b isa Rational + if isinteger(b) + coeff = a.coeff ^ numerator(b) + else + coeff = 1 + new_dict[term(^, a.coeff, b)] = 1 + end + else + coeff = unstable_pow(a.coeff, b) + end + Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b*v, a.dict)) + coeff, new_dict) else Pow(a, b) end diff --git a/test/basics.jl b/test/basics.jl index 7c29162e..08bbb624 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -363,5 +363,23 @@ end ax = adjoint(x) @test isequal(ax, x) @test ax === x - @test isequal(adjoint(y), conj(y)) + @test isequal(adjoint(y), conj(y)) +end + +@testset "pow" begin + @syms x y + @eqtest (2x)^(1//2) == term(^, 2, 1//2) * x^(1//2) + @eqtest (2x)^(1//1) == 2x + # @eqtest (2x)^1 == 2x ## This currently fails and returns (2//1)*x + @eqtest (2x)^(2//1) == 4x^(2//1) + @eqtest ((1//3)*x)^(1//4) == term(^, 1//3, 1//4) * x^(1//4) + + @eqtest (x+y)^(1//2) == (x+y)^(1//2) + @eqtest (x+y)^(1//1) == x+y + + @eqtest (x^(3//1))^(1//3) == x + @eqtest (x^(3//1))^(2//3) == x^(2//1) + @eqtest (x^(2//1))^2 == x^(4//1) + + @test ((2x)^0.5).coeff ≈ sqrt(2) end diff --git a/test/rulesets.jl b/test/rulesets.jl index 4730cc10..aff1ecd3 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -52,6 +52,14 @@ end @test simplify(Term(zero, [a])) == 0 @test simplify(Term(zero, [b + 1])) == 0 @test simplify(Term(zero, [x + 2])) == 0 + + @eqtest simplify(Term(sqrt, [2])) == Term(sqrt, [2]) + @eqtest simplify(Term(^, [2, 1//2])) == Term(^, [2, 1//2]) + @eqtest simplify(Term(^, [2x, 1//2])) == Term(^, [2, 1//2]) * x^(1//2) + @test simplify(Term(^, [2, 3])) ≈ 8 + @test simplify(Term(^, [1//3, 3])) == 1//27 + @test simplify(Term(^, [2, 0.5])) ≈ 2^0.5 + @test simplify(Term(^, [2.5, 0.25])) ≈ 2.5^(0.25) end @testset "LiteralReal" begin