Skip to content

Automatic JuliaFormatter.jl run #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,14 @@ function benchmark_utilities()
[get_set_constants!(ex) for ex in exs],
seconds = 10.0,
setup = (
operators = $operators;
ntrees = 100;
n = 20;
n_features = 5;
n_params = 3;
n_param_classes = 10;
rng = Random.MersenneTwister(0);
exs = [
operators=($operators);
ntrees=100;
n=20;
n_features=5;
n_params=3;
n_param_classes=10;
rng=Random.MersenneTwister(0);
exs=[
let tree = gen_random_tree_fixed_size(
n, operators, n_features, Float32, ParametricNode, rng
)
Expand Down
4 changes: 2 additions & 2 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function split_eq(
op,
args,
operators::AbstractOperatorEnum,
::Type{N}=Node;
(::Type{N})=Node;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
# Deprecated:
varMap=nothing,
Expand Down Expand Up @@ -255,7 +255,7 @@ end
function symbolic_to_node(
eqn::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum,
::Type{N}=Node;
(::Type{N})=Node;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
# Deprecated:
varMap=nothing,
Expand Down
2 changes: 1 addition & 1 deletion ext/DynamicExpressionsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGrad
function _zygote_gradient(op::F, ::Val{1}) where {F}
return ZygoteGradient{F,1,1}(op)
end
function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side}
function _zygote_gradient(op::F, ::Val{2}, (::Val{side})=Val(nothing)) where {F,side}
# side should be either nothing (for both), 1, or 2
@assert side === nothing || side in (1, 2)
return ZygoteGradient{F,2,side}(op)
Expand Down
3 changes: 2 additions & 1 deletion src/EvaluationHelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ to every constant in the expression.
- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
"""
Base.adjoint(tree::AbstractExpressionNode) =
function Base.adjoint(tree::AbstractExpressionNode)
((args...; kws...) -> _grad_evaluator(tree, args...; kws...))
end

end
2 changes: 1 addition & 1 deletion src/Expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ end
function copy_into!(::Nothing, src::AbstractExpression)
return copy(src)
end
function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
function allocate_container(::AbstractExpression, (::Union{Nothing,Integer})=nothing)
return nothing
end
# COV_EXCL_STOP
Expand Down
46 changes: 21 additions & 25 deletions src/ExpressionAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,32 +107,28 @@ the operator is unary (1) or binary (2).
macro declare_expression_operator(op, arity)
@assert arity ∈ (1, 2)
if arity == 1
return esc(
quote
$op(l::AbstractExpression) = $(apply_operator)($op, l)
end,
)
return esc(quote
$op(l::AbstractExpression) = $(apply_operator)($op, l)
end)
elseif arity == 2
return esc(
quote
function $op(l::AbstractExpression, r::AbstractExpression)
return $(apply_operator)($op, l, r)
end
function $op(l::T, r::AbstractExpression{T}) where {T}
return $(apply_operator)($op, l, r)
end
function $op(l::AbstractExpression{T}, r::T) where {T}
return $(apply_operator)($op, l, r)
end
# Convenience methods for Number types
function $op(l::Number, r::AbstractExpression{T}) where {T}
return $(apply_operator)($op, l, r)
end
function $op(l::AbstractExpression{T}, r::Number) where {T}
return $(apply_operator)($op, l, r)
end
end,
)
return esc(quote
function $op(l::AbstractExpression, r::AbstractExpression)
return $(apply_operator)($op, l, r)
end
function $op(l::T, r::AbstractExpression{T}) where {T}
return $(apply_operator)($op, l, r)
end
function $op(l::AbstractExpression{T}, r::T) where {T}
return $(apply_operator)($op, l, r)
end
# Convenience methods for Number types
function $op(l::Number, r::AbstractExpression{T}) where {T}
return $(apply_operator)($op, l, r)
end
function $op(l::AbstractExpression{T}, r::Number) where {T}
return $(apply_operator)($op, l, r)
end
end)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/NodeUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ given the output of this function.
Also return metadata that can will be used in the `set_scalar_constants!` function.
"""
function get_scalar_constants(
tree::AbstractExpressionNode{T}, ::Type{BT}=get_number_type(T)
tree::AbstractExpressionNode{T}, (::Type{BT})=get_number_type(T)
) where {T,BT}
refs = filter_map(
is_node_constant, node -> Ref(node), tree, Base.RefValue{typeof(tree)}
Expand Down Expand Up @@ -160,7 +160,7 @@ end
# as we trace over the node we are indexing on.
preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false

function index_constant_nodes(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T}
function index_constant_nodes(tree::AbstractExpressionNode, (::Type{T})=UInt16) where {T}
# Essentially we copy the tree, replacing the values
# with indices
constant_index = Ref(T(0))
Expand Down
28 changes: 12 additions & 16 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,12 @@ defined.
macro extend_operators(operators, kws...)
ex = _extend_operators(operators, false, kws, __module__)
expected_type = AbstractOperatorEnum
return esc(
quote
if !isa($(operators), $expected_type)
error("You must pass an operator enum to `@extend_operators`.")
end
$ex
end,
)
return esc(quote
if !isa($(operators), $expected_type)
error("You must pass an operator enum to `@extend_operators`.")
end
$ex
end)
end

"""
Expand All @@ -399,14 +397,12 @@ and `internal` which is default `false`.
macro extend_operators_base(operators, kws...)
ex = _extend_operators(operators, true, kws, __module__)
expected_type = AbstractOperatorEnum
return esc(
quote
if !isa($(operators), $expected_type)
error("You must pass an operator enum to `@extend_operators_base`.")
end
$ex
end,
)
return esc(quote
if !isa($(operators), $expected_type)
error("You must pass an operator enum to `@extend_operators_base`.")
end
$ex
end)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
elseif leaf.is_parameter
Node(T; feature=leaf.parameter)
else
Node(T; feature=leaf.feature + num_params)
Node(T; feature=(leaf.feature + num_params))
end,
branch -> branch.op,
(op, children...) -> Node(; op, children),
Expand Down
14 changes: 7 additions & 7 deletions src/Parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ macro parse_expression(ex, kws...)
return esc(
:($(parse_expression)(
$(Meta.quot(ex));
operators=$(parsed_kws.operators),
operators=($(parsed_kws.operators)),
binary_operators=nothing,
unary_operators=nothing,
variable_names=$(parsed_kws.variable_names),
node_type=$(parsed_kws.node_type),
expression_type=$(parsed_kws.expression_type),
evaluate_on=$(parsed_kws.evaluate_on),
variable_names=($(parsed_kws.variable_names)),
node_type=($(parsed_kws.node_type)),
expression_type=($(parsed_kws.expression_type)),
evaluate_on=($(parsed_kws.evaluate_on)),
$(parsed_kws.extra_metadata)...,
)),
)
Expand Down Expand Up @@ -188,8 +188,8 @@ end
"You must specify the operators using either `operators`, or `binary_operators` and `unary_operators`"
)
operators = :($(OperatorEnum)(;
binary_operators=$(binops === nothing ? :(Function[]) : binops),
unary_operators=$(unaops === nothing ? :(Function[]) : unaops),
binary_operators=($(binops === nothing ? :(Function[]) : binops)),
unary_operators=($(unaops === nothing ? :(Function[]) : unaops)),
))
else
@assert (binops === nothing && unaops === nothing)
Expand Down
3 changes: 2 additions & 1 deletion src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ end

Sample a node from a tree according to the default sampler `NodeSampler(; tree)`.
"""
rand(rng::AbstractRNG, tree::Union{AbstractNode,AbstractExpression}) =
function rand(rng::AbstractRNG, tree::Union{AbstractNode,AbstractExpression})
rand(rng, NodeSampler(; tree))
end

"""
rand(rng::AbstractRNG, sampler::NodeSampler)
Expand Down
21 changes: 9 additions & 12 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import PrecompileTools: @compile_workload, @setup_workload

macro ignore_domain_error(ex)
return esc(
quote
try
$ex
catch e
if !(e isa DomainError)
rethrow(e)
end
return esc(quote
try
$ex
catch e
if !(e isa DomainError)
rethrow(e)
end
end,
)
end
end)
end

"""
Expand All @@ -21,8 +19,7 @@ Test all combinations of the given operators and types. Useful for precompilatio
"""
function test_all_combinations(; binary_operators, unary_operators, turbo, types)
for binops in binary_operators,
unaops in unary_operators,
use_turbo in turbo,
unaops in unary_operators, use_turbo in turbo,
T in types

length(binops) == 0 && length(unaops) == 0 && continue
Expand Down
16 changes: 8 additions & 8 deletions test/test_deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ end

if VERSION >= v"1.9"
@test_logs (:warn, r"Node\(d, c, v\) is deprecated.*") (
n = Node(1, true, 1.0 + 0im); @assert (n.val isa ComplexF64)
n=Node(1, true, 1.0 + 0im); @assert (n.val isa ComplexF64)
)
@test_logs (:warn, r"Node\(T, d, c, v\) is deprecated.*") (
n = Node(Float32, 1, true, 1.0 + 0im); @assert (n.val isa Float32)
n=Node(Float32, 1, true, 1.0 + 0im); @assert (n.val isa Float32)
)
@test_logs (:warn, r"Node\(T, d, c, v, f\) is deprecated.*") (
n = Node(Float32, 1, false, nothing, 1); @assert (n.feature == 1)
n=Node(Float32, 1, false, nothing, 1); @assert (n.feature == 1)
)
@test_logs (:warn, r"Node\(d, c, v, f, o, l\) is deprecated.*") (
x1 = Node(; feature=1);
n = Node(1, true, nothing, 1, 3, x1);
x1=Node(; feature=1);
n=Node(1, true, nothing, 1, 3, x1);
@assert (n.op == 3 && n.l === x1)
)
@test_logs (:warn, r"Node\(d, c, v, f, o, l, r\) is deprecated.*") (
x1 = Node(; feature=1);
x2 = Node(; feature=2);
n = Node(2, true, nothing, 1, 1, x1, x2);
x1=Node(; feature=1);
x2=Node(; feature=2);
n=Node(2, true, nothing, 1, 1, x1, x2);
@assert (n.op == 1 && n.l === x1 && n.r === x2)
)
end
Expand Down
15 changes: 9 additions & 6 deletions test/test_evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,27 @@ end
@test repr(tree) == "cos(cos(3.0))"
tree = convert(Node{T}, tree)
truth = cos(cos(T(3.0f0)))
@test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, EvalOptions(; turbo)).x[1] ≈
truth
@test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(
tree, [zero(T)]', cos, cos, EvalOptions(; turbo)
).x[1] ≈ truth

# op(<constant>, <constant>)
tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0))
@test repr(tree) == "3.0 + 4.0"
tree = convert(Node{T}, tree)
truth = T(3.0f0) + T(4.0f0)
@test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), EvalOptions(; turbo)).x[1] ≈
truth
@test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(
tree, [zero(T)]', (+), EvalOptions(; turbo)
).x[1] ≈ truth

# op(op(<constant>, <constant>))
tree = Node(1, Node(1, Node(; val=3.0f0), Node(; val=4.0f0)))
@test repr(tree) == "cos(3.0 + 4.0)"
tree = convert(Node{T}, tree)
truth = cos(T(3.0f0) + T(4.0f0))
@test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), EvalOptions(; turbo)).x[1] ≈
truth
@test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(
tree, [zero(T)]', cos, (+), EvalOptions(; turbo)
).x[1] ≈ truth

# Test for presence of NaNs:
operators = OperatorEnum(;
Expand Down
4 changes: 2 additions & 2 deletions test/test_extra_node_fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ m.frozen = !m.frozen
@test n != m

# Try out an interface for freezing parts of an expression
freeze!(n) = (n.frozen = true; n)
thaw!(n) = (n.frozen = false; n)
freeze!(n) = (n.frozen=true; n)
thaw!(n) = (n.frozen=false; n)

ex = parse_expression(
:(x + $freeze!(sin($thaw!(y + 2.1))));
Expand Down