From c77d318f39f689fa658df6ec383680fb314251fc Mon Sep 17 00:00:00 2001 From: Maximilian HUEBL Date: Wed, 28 Aug 2024 14:15:48 +0200 Subject: [PATCH 1/5] lazy evaluation of rational powers --- src/types.jl | 30 +++++++++++++++++++++++++----- test/rulesets.jl | 8 ++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/types.jl b/src/types.jl index 6dc2320e9..11bad85e3 100644 --- a/src/types.jl +++ b/src/types.jl @@ -608,9 +608,17 @@ function basicsymbolic(f, args, stype, metadata) end res elseif f == (^) && length(args) == 2 - res = args[1] ^ args[2] - if ispow(res) - @set! res.metadata = metadata + if args[2] isa Rational && !(args[1] isa Symbolic) + if !isinteger(args[2]) + @goto FALLBACK + end + integer_type = only(typeof(args[2]).parameters) + res = args[1] ^ convert(integer_type, args[2]) + else + res = args[1] ^ args[2] + if ispow(res) + @set! res.metadata = metadata + end end res else @@ -1210,9 +1218,21 @@ function ^(a::SN, b) elseif b isa Number && b < 0 Div(1, a ^ (-b)) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.coeff, b) + new_dict = mapvalues((k, v) -> b*v, a.dict) + if b isa Rational + if isinteger(b) + integer_type = only(typeof(b).parameters) + coeff = a.coeff ^ convert(integer_type, b) + else + coeff = 1 + merge!(new_dict, Dict(term(^, a.coeff, b) => 1)) + end + else + coeff = unstable_pow(a.coeff, b) + end + Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b*v, a.dict)) + coeff, new_dict) else Pow(a, b) end diff --git a/test/rulesets.jl b/test/rulesets.jl index 4730cc102..aff1ecd3f 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -52,6 +52,14 @@ end @test simplify(Term(zero, [a])) == 0 @test simplify(Term(zero, [b + 1])) == 0 @test simplify(Term(zero, [x + 2])) == 0 + + @eqtest simplify(Term(sqrt, [2])) == Term(sqrt, [2]) + @eqtest simplify(Term(^, [2, 1//2])) == Term(^, [2, 1//2]) + @eqtest simplify(Term(^, [2x, 1//2])) == Term(^, [2, 1//2]) * x^(1//2) + @test simplify(Term(^, [2, 3])) ≈ 8 + @test simplify(Term(^, [1//3, 3])) == 1//27 + @test simplify(Term(^, [2, 0.5])) ≈ 2^0.5 + @test simplify(Term(^, [2.5, 0.25])) ≈ 2.5^(0.25) end @testset "LiteralReal" begin From 2daeb7ec7c5a889a0ca10879177ef41e5ca8d731 Mon Sep 17 00:00:00 2001 From: Maximilian HUEBL Date: Thu, 29 Aug 2024 17:23:42 +0200 Subject: [PATCH 2/5] add pow tests --- test/basics.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/basics.jl b/test/basics.jl index f40da5e2d..cc3276d5b 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -357,3 +357,21 @@ end end @test repr(sin(x) + sin(x)) == "sin(x) + sin(x)" end + +@testset "pow" begin + @syms x y + @eqtest (2x)^(1//2) == term(^, 2, 1//2) * x^(1//2) + @eqtest (2x)^(1//1) == 2x + # @eqtest (2x)^1 == 2x ## This currently fails and returns (2//1)*x + @eqtest (2x)^(2//1) == 4x^(2//1) + @eqtest ((1//3)*x)^(1//4) == term(^, 1//3, 1//4) * x^(1//4) + + @eqtest (x+y)^(1//2) == (x+y)^(1//2) + @eqtest (x+y)^(1//1) == x+y + + @eqtest (x^(3//1))^(1//3) == x + @eqtest (x^(3//1))^(2//3) == x^(2//1) + @eqtest (x^(2//1))^2 == x^(4//1) + + @test ((2x)^0.5).coeff ≈ sqrt(2) +end From f54464fc216a96953dc17ba49021b5de44c02113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20H=C3=BCbl?= <89022871+mxhbl@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:23:13 +0200 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: Bowen S. Zhu --- src/types.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/types.jl b/src/types.jl index 11bad85e3..e2692dc5b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -608,17 +608,17 @@ function basicsymbolic(f, args, stype, metadata) end res elseif f == (^) && length(args) == 2 - if args[2] isa Rational && !(args[1] isa Symbolic) - if !isinteger(args[2]) + if args[2] isa Rational && !isa(args[1], Symbolic) + if isinteger(args[2]) + res = args[1] ^ numerator(args[2]) + else @goto FALLBACK end - integer_type = only(typeof(args[2]).parameters) - res = args[1] ^ convert(integer_type, args[2]) else res = args[1] ^ args[2] - if ispow(res) - @set! res.metadata = metadata - end + end + if ispow(res) + @set! res.metadata = metadata end res else From e06f2c78ec761e7a39771095458c3bbf88d2ab22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20H=C3=BCbl?= <89022871+mxhbl@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:52:20 +0200 Subject: [PATCH 4/5] Update src/types.jl Co-authored-by: Bowen S. Zhu --- src/types.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index e2692dc5b..cefc7e468 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1221,8 +1221,7 @@ function ^(a::SN, b) new_dict = mapvalues((k, v) -> b*v, a.dict) if b isa Rational if isinteger(b) - integer_type = only(typeof(b).parameters) - coeff = a.coeff ^ convert(integer_type, b) + coeff = a.coeff ^ numerator(b) else coeff = 1 merge!(new_dict, Dict(term(^, a.coeff, b) => 1)) From 11fc4afd7ca70c36c1eaff7fb30997683ce99fb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20H=C3=BCbl?= Date: Tue, 10 Sep 2024 09:13:15 +0200 Subject: [PATCH 5/5] code review --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index cefc7e468..99723863e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1224,7 +1224,7 @@ function ^(a::SN, b) coeff = a.coeff ^ numerator(b) else coeff = 1 - merge!(new_dict, Dict(term(^, a.coeff, b) => 1)) + new_dict[term(^, a.coeff, b)] = 1 end else coeff = unstable_pow(a.coeff, b)