/* * R : A Computer Language for Statistical Data Analysis * Copyright (C) 1998-2023 The R Core Team. * Copyright (C) 2004-2017 The R Foundation * Copyright (C) 1995, 1996 Robert Gentleman and Ross Ihaka * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, a copy is available at * https://www.R-project.org/Licenses/ * * * Symbolic Differentiation */ #ifdef HAVE_CONFIG_H #include #endif #include "Defn.h" #undef _ #ifdef ENABLE_NLS #include #define _(String) dgettext ("stats", String) #else #define _(String) (String) #endif static SEXP ParenSymbol; static SEXP PlusSymbol; static SEXP MinusSymbol; static SEXP TimesSymbol; static SEXP DivideSymbol; static SEXP PowerSymbol; static SEXP ExpSymbol; static SEXP LogSymbol; static SEXP SinSymbol; static SEXP CosSymbol; static SEXP TanSymbol; static SEXP SinhSymbol; static SEXP CoshSymbol; static SEXP TanhSymbol; static SEXP SqrtSymbol; static SEXP PnormSymbol; static SEXP DnormSymbol; static SEXP AsinSymbol; static SEXP AcosSymbol; static SEXP AtanSymbol; static SEXP GammaSymbol; static SEXP LGammaSymbol; static SEXP DiGammaSymbol; static SEXP TriGammaSymbol; static SEXP PsiSymbol; /* new symbols in R 3.4.0: */ static SEXP PiSymbol; static SEXP ExpM1Symbol; static SEXP Log1PSymbol; static SEXP Log2Symbol; static SEXP Log10Symbol; static SEXP SinPiSymbol; static SEXP CosPiSymbol; static SEXP TanPiSymbol; static SEXP FactorialSymbol; static SEXP LFactorialSymbol; /* possible future symbols 
static SEXP Log1PExpSymbol; static SEXP Log1MExpSymbol; static SEXP Log1PMxSymbol; */ static Rboolean Initialized = FALSE; static void InitDerivSymbols(void) { /* Called from doD() and deriv() */ if(Initialized) return; ParenSymbol = install("("); PlusSymbol = install("+"); MinusSymbol = install("-"); TimesSymbol = install("*"); DivideSymbol = install("/"); PowerSymbol = install("^"); ExpSymbol = install("exp"); LogSymbol = install("log"); SinSymbol = install("sin"); CosSymbol = install("cos"); TanSymbol = install("tan"); SinhSymbol = install("sinh"); CoshSymbol = install("cosh"); TanhSymbol = install("tanh"); SqrtSymbol = install("sqrt"); PnormSymbol = install("pnorm"); DnormSymbol = install("dnorm"); AsinSymbol = install("asin"); AcosSymbol = install("acos"); AtanSymbol = install("atan"); GammaSymbol = install("gamma"); LGammaSymbol = install("lgamma"); DiGammaSymbol = install("digamma"); TriGammaSymbol = install("trigamma"); PsiSymbol = install("psigamma"); /* new symbols */ PiSymbol = install("pi"); ExpM1Symbol = install("expm1"); Log1PSymbol = install("log1p"); Log2Symbol = install("log2"); Log10Symbol = install("log10"); SinPiSymbol = install("sinpi"); CosPiSymbol = install("cospi"); TanPiSymbol = install("tanpi"); FactorialSymbol = install("factorial"); LFactorialSymbol = install("lfactorial"); /* possible future symbols Log1PExpSymbol = install("log1pexp"); # log(1+exp(x)) Log1MExpSymbol = install("log1mexp"); # log(1-exp(-x)), for x > 0 Log1PMxSymbol = install("log1pmx"); # log1p(x)-x */ Initialized = TRUE; } static SEXP Constant(double x) { return ScalarReal(x); } static int isZero(SEXP s) { return asReal(s) == 0.0; } static int isOne(SEXP s) { return asReal(s) == 1.0; } static int isUminus(SEXP s) { if (TYPEOF(s) == LANGSXP && CAR(s) == MinusSymbol) { switch(length(s)) { case 2: return 1; case 3: if (CADDR(s) == R_MissingArg) return 1; else return 0; default: error(_("invalid form in unary minus check")); return -1;/* for -Wall */ } } else return 0; } /* 
Pointer protect and return the argument */ static SEXP PP(SEXP s) { PROTECT(s); return s; } static SEXP simplify(SEXP fun, SEXP arg1, SEXP arg2) { SEXP ans; if (fun == PlusSymbol) { if (isZero(arg1)) ans = arg2; else if (isZero(arg2)) ans = arg1; else if (isUminus(arg1)) ans = simplify(MinusSymbol, arg2, CADR(arg1)); else if (isUminus(arg2)) ans = simplify(MinusSymbol, arg1, CADR(arg2)); else ans = lang3(PlusSymbol, arg1, arg2); } else if (fun == MinusSymbol) { if (arg2 == R_MissingArg) { if (isZero(arg1)) ans = Constant(0.); else if (isUminus(arg1)) ans = CADR(arg1); else ans = lang2(MinusSymbol, arg1); } else { if (isZero(arg2)) ans = arg1; else if (isZero(arg1)) ans = simplify(MinusSymbol, arg2, R_MissingArg); else if (isUminus(arg1)) { ans = simplify(MinusSymbol, PP(simplify(PlusSymbol, CADR(arg1), arg2)), R_MissingArg); UNPROTECT(1); } else if (isUminus(arg2)) ans = simplify(PlusSymbol, arg1, CADR(arg2)); else ans = lang3(MinusSymbol, arg1, arg2); } } else if (fun == TimesSymbol) { if (isZero(arg1) || isZero(arg2)) ans = Constant(0.); else if (isOne(arg1)) ans = arg2; else if (isOne(arg2)) ans = arg1; else if (isUminus(arg1)) { ans = simplify(MinusSymbol, PP(simplify(TimesSymbol, CADR(arg1), arg2)), R_MissingArg); UNPROTECT(1); } else if (isUminus(arg2)) { ans = simplify(MinusSymbol, PP(simplify(TimesSymbol, arg1, CADR(arg2))), R_MissingArg); UNPROTECT(1); } else ans = lang3(TimesSymbol, arg1, arg2); } else if (fun == DivideSymbol) { if (isZero(arg1)) ans = Constant(0.); else if (isZero(arg2)) ans = Constant(NA_REAL); else if (isOne(arg2)) ans = arg1; else if (isUminus(arg1)) { ans = simplify(MinusSymbol, PP(simplify(DivideSymbol, CADR(arg1), arg2)), R_MissingArg); UNPROTECT(1); } else if (isUminus(arg2)) { ans = simplify(MinusSymbol, PP(simplify(DivideSymbol, arg1, CADR(arg2))), R_MissingArg); UNPROTECT(1); } else ans = lang3(DivideSymbol, arg1, arg2); } else if (fun == PowerSymbol) { if (isZero(arg2)) ans = Constant(1.); else if (isZero(arg1)) ans = 
Constant(0.); else if (isOne(arg1)) ans = Constant(1.); else if (isOne(arg2)) ans = arg1; else ans = lang3(PowerSymbol, arg1, arg2); } else if (fun == ExpSymbol) { /* FIXME: simplify exp(lgamma( E )) = gamma( E ) */ /* FIXME: simplify exp(lfactorial( E )) = factorial( E ) */ ans = lang2(ExpSymbol, arg1); } else if (fun == LogSymbol) { /* FIXME: simplify log(gamma( E )) = lgamma( E ) */ /* FIXME: simplify log(factorial( E )) = lfactorial( E ) */ ans = lang2(LogSymbol, arg1); } else if (fun == CosSymbol) ans = lang2(CosSymbol, arg1); else if (fun == SinSymbol) ans = lang2(SinSymbol, arg1); else if (fun == TanSymbol) ans = lang2(TanSymbol, arg1); else if (fun == CoshSymbol) ans = lang2(CoshSymbol, arg1); else if (fun == SinhSymbol) ans = lang2(SinhSymbol, arg1); else if (fun == TanhSymbol) ans = lang2(TanhSymbol, arg1); else if (fun == SqrtSymbol) ans = lang2(SqrtSymbol, arg1); else if (fun == PnormSymbol)ans = lang2(PnormSymbol, arg1); else if (fun == DnormSymbol)ans = lang2(DnormSymbol, arg1); else if (fun == AsinSymbol) ans = lang2(AsinSymbol, arg1); else if (fun == AcosSymbol) ans = lang2(AcosSymbol, arg1); else if (fun == AtanSymbol) ans = lang2(AtanSymbol, arg1); else if (fun == GammaSymbol)ans = lang2(GammaSymbol, arg1); else if (fun == LGammaSymbol)ans = lang2(LGammaSymbol, arg1); else if (fun == DiGammaSymbol) ans = lang2(DiGammaSymbol, arg1); else if (fun == TriGammaSymbol) ans = lang2(TriGammaSymbol, arg1); else if (fun == PsiSymbol){ if (arg2 == R_MissingArg) ans = lang2(PsiSymbol, arg1); else ans = lang3(PsiSymbol, arg1, arg2); } /* new symbols */ else if (fun == ExpM1Symbol) { /* FIXME: simplify expm1(log1p( E )) = E */ ans = lang2(ExpM1Symbol, arg1); } else if (fun == LogSymbol) { /* FIXME: simplify log1p(expm1( E )) = E */ ans = lang2(Log1PSymbol, arg1); } else if (fun == Log2Symbol) ans = lang2(Log2Symbol, arg1); else if (fun == Log10Symbol) ans = lang2(Log10Symbol, arg1); else if (fun == CosPiSymbol) ans = lang2(CosPiSymbol, arg1); else if (fun == 
SinPiSymbol) ans = lang2(SinPiSymbol, arg1); else if (fun == TanPiSymbol) ans = lang2(TanPiSymbol, arg1); else if (fun == FactorialSymbol)ans = lang2(FactorialSymbol, arg1); else if (fun == LFactorialSymbol)ans = lang2(LFactorialSymbol, arg1); /* possible future symbols else if (fun == Log1PExpSymbol) ans = lang2(Log1PExpSymbol, arg1); else if (fun == Log1MExpSymbol) ans = lang2(Log1MExpSymbol, arg1); else if (fun == Log1PMxSymbol) ans = lang2(Log1PMxSymbol, arg1); */ else ans = Constant(NA_REAL); /* FIXME */ #ifdef NOTYET if (length(ans) == 2 && isAtomic(CADR(ans)) && CAR(ans) != MinusSymbol) c = eval(c, rho); if (length(c) == 3 && isAtomic(CADR(ans)) && isAtomic(CADDR(ans))) c = eval(c, rho); #endif return ans; }/* simplify() */ /* D() implements the "derivative table" : */ static SEXP D(SEXP expr, SEXP var) { #define PP_S(F,a1,a2) PP(simplify(F,a1,a2)) #define PP_S2(F,a1) PP(simplify(F,a1, R_MissingArg)) SEXP ans = R_NilValue, expr1, expr2; switch(TYPEOF(expr)) { case LGLSXP: case INTSXP: case REALSXP: case CPLXSXP: ans = Constant(0); break; case SYMSXP: if (expr == var) ans = Constant(1.); else ans = Constant(0.); break; case LISTSXP: if (inherits(expr, "expression")) ans = D(CAR(expr), var); else ans = Constant(NA_REAL); break; case LANGSXP: if (CAR(expr) == ParenSymbol) { ans = D(CADR(expr), var); } else if (CAR(expr) == PlusSymbol) { if (length(expr) == 2) ans = D(CADR(expr), var); else { ans = simplify(PlusSymbol, PP(D(CADR(expr), var)), PP(D(CADDR(expr), var))); UNPROTECT(2); } } else if (CAR(expr) == MinusSymbol) { if (length(expr) == 2) { ans = simplify(MinusSymbol, PP(D(CADR(expr), var)), R_MissingArg); UNPROTECT(1); } else { ans = simplify(MinusSymbol, PP(D(CADR(expr), var)), PP(D(CADDR(expr), var))); UNPROTECT(2); } } else if (CAR(expr) == TimesSymbol) { ans = simplify(PlusSymbol, PP_S(TimesSymbol,PP(D(CADR(expr),var)), CADDR(expr)), PP_S(TimesSymbol,CADR(expr), PP(D(CADDR(expr),var)))); UNPROTECT(4); } else if (CAR(expr) == DivideSymbol) { 
PROTECT(expr1 = D(CADR(expr), var)); PROTECT(expr2 = D(CADDR(expr), var)); ans = simplify(MinusSymbol, PP_S(DivideSymbol, expr1, CADDR(expr)), PP_S(DivideSymbol, PP_S(TimesSymbol, CADR(expr), expr2), PP_S(PowerSymbol,CADDR(expr),PP(Constant(2.))))); UNPROTECT(7); } else if (CAR(expr) == PowerSymbol) { if (isLogical(CADDR(expr)) || isNumeric(CADDR(expr))) { ans = simplify(TimesSymbol, CADDR(expr), PP_S(TimesSymbol, PP(D(CADR(expr), var)), PP_S(PowerSymbol, CADR(expr), PP(Constant(asReal(CADDR(expr))-1.))))); UNPROTECT(4); } else { expr1 = simplify(TimesSymbol, PP_S(PowerSymbol, CADR(expr), PP_S(MinusSymbol, CADDR(expr), PP(Constant(1.0)))), PP_S(TimesSymbol, CADDR(expr), PP(D(CADR(expr), var)))); UNPROTECT(5); PROTECT(expr1); expr2 = simplify(TimesSymbol, PP_S(PowerSymbol, CADR(expr), CADDR(expr)), PP_S(TimesSymbol, PP_S2(LogSymbol, CADR(expr)), PP(D(CADDR(expr), var)))); UNPROTECT(4); PROTECT(expr2); ans = simplify(PlusSymbol, expr1, expr2); UNPROTECT(2); } } else if (CAR(expr) == ExpSymbol) { ans = simplify(TimesSymbol, expr, PP(D(CADR(expr), var))); UNPROTECT(1); } else if (CAR(expr) == LogSymbol) { if (length(expr) != 2) error("only single-argument calls to log() are supported;\n" " maybe use log(x,a) = log(x)/log(a)"); ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), CADR(expr)); UNPROTECT(1); } else if (CAR(expr) == CosSymbol) { ans = simplify(TimesSymbol, PP_S2(SinSymbol, CADR(expr)), PP_S2(MinusSymbol, PP(D(CADR(expr), var)))); UNPROTECT(3); } else if (CAR(expr) == SinSymbol) { ans = simplify(TimesSymbol, PP_S2(CosSymbol, CADR(expr)), PP(D(CADR(expr), var))); UNPROTECT(2); } else if (CAR(expr) == TanSymbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(PowerSymbol, PP_S2(CosSymbol, CADR(expr)), PP(Constant(2.0)))); UNPROTECT(4); } else if (CAR(expr) == CoshSymbol) { ans = simplify(TimesSymbol, PP_S2(SinhSymbol, CADR(expr)), PP(D(CADR(expr), var))); UNPROTECT(2); } else if (CAR(expr) == SinhSymbol) { ans = simplify(TimesSymbol, 
PP_S2(CoshSymbol, CADR(expr)), PP(D(CADR(expr), var))), UNPROTECT(2); } else if (CAR(expr) == TanhSymbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(PowerSymbol, PP_S2(CoshSymbol, CADR(expr)), PP(Constant(2.0)))); UNPROTECT(4); } else if (CAR(expr) == SqrtSymbol) { PROTECT(expr1 = allocList(3)); SET_TYPEOF(expr1, LANGSXP); SETCAR(expr1, PowerSymbol); SETCADR(expr1, CADR(expr)); SETCADDR(expr1, Constant(0.5)); ans = D(expr1, var); UNPROTECT(1); } else if (CAR(expr) == PnormSymbol) { ans = simplify(TimesSymbol, PP_S2(DnormSymbol, CADR(expr)), PP(D(CADR(expr), var))); UNPROTECT(2); } else if (CAR(expr) == DnormSymbol) { ans = simplify(TimesSymbol, PP_S2(MinusSymbol, CADR(expr)), PP_S(TimesSymbol, PP_S2(DnormSymbol, CADR(expr)), PP(D(CADR(expr), var)))); UNPROTECT(4); } else if (CAR(expr) == AsinSymbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(SqrtSymbol, PP_S(MinusSymbol, PP(Constant(1.)), PP_S(PowerSymbol,CADR(expr),PP(Constant(2.)))), R_MissingArg)); UNPROTECT(6); } else if (CAR(expr) == AcosSymbol) { ans = simplify(MinusSymbol, PP_S(DivideSymbol, PP(D(CADR(expr), var)), PP_S(SqrtSymbol, PP_S(MinusSymbol, PP(Constant(1.)), PP_S(PowerSymbol, CADR(expr),PP(Constant(2.)))), R_MissingArg)), R_MissingArg); UNPROTECT(7); } else if (CAR(expr) == AtanSymbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(PlusSymbol,PP(Constant(1.)), PP_S(PowerSymbol, CADR(expr),PP(Constant(2.))))); UNPROTECT(5); } else if (CAR(expr) == LGammaSymbol) { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S2(DiGammaSymbol, CADR(expr))); UNPROTECT(2); } else if (CAR(expr) == GammaSymbol) { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S(TimesSymbol, expr, PP_S2(DiGammaSymbol, CADR(expr)))); UNPROTECT(3); } else if (CAR(expr) == DiGammaSymbol) { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S2(TriGammaSymbol, CADR(expr))); UNPROTECT(2); } else if (CAR(expr) == TriGammaSymbol) { ans = simplify(TimesSymbol, PP(D(CADR(expr), 
var)), PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(2)))); UNPROTECT(3); } else if (CAR(expr) == PsiSymbol) { if (length(expr) == 2){ ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(1)))); UNPROTECT(3); } else if (TYPEOF(CADDR(expr)) == INTSXP || TYPEOF(CADDR(expr)) == REALSXP) { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(asInteger(CADDR(expr))+1)))); UNPROTECT(3); } else { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S(PsiSymbol, CADR(expr), simplify(PlusSymbol, CADDR(expr), PP(ScalarInteger(1))))); UNPROTECT(3); } } /* new in R 3.4.0 */ else if (CAR(expr) == ExpM1Symbol) { ans = simplify(TimesSymbol, PP_S2(ExpSymbol, CADR(expr)), PP(D(CADR(expr), var))); UNPROTECT(2); } else if (CAR(expr) == Log1PSymbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(PlusSymbol, PP(Constant(1.)), CADR(expr))); UNPROTECT(3); } else if (CAR(expr) == Log2Symbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(TimesSymbol, CADR(expr), PP_S2(LogSymbol, PP(Constant(2.))))); UNPROTECT(4); } else if (CAR(expr) == Log10Symbol) { ans = simplify(DivideSymbol, PP(D(CADR(expr), var)), PP_S(TimesSymbol, CADR(expr), PP_S2(LogSymbol, PP(Constant(10.))))); UNPROTECT(4); } else if (CAR(expr) == CosPiSymbol) { ans = simplify(TimesSymbol, PP_S2(SinPiSymbol, CADR(expr)), PP_S(TimesSymbol, PP_S2(MinusSymbol, PiSymbol), PP(D(CADR(expr), var)) )); UNPROTECT(4); } else if (CAR(expr) == SinPiSymbol) { ans = simplify(TimesSymbol, PP_S2(CosPiSymbol, CADR(expr)), PP_S(TimesSymbol, PiSymbol, PP(D(CADR(expr), var)) )); UNPROTECT(3); } else if (CAR(expr) == TanPiSymbol) { ans = simplify(DivideSymbol, PP_S(TimesSymbol, PiSymbol, PP(D(CADR(expr), var))), PP_S(PowerSymbol, PP_S2(CosPiSymbol, CADR(expr)), PP(Constant(2.0)))); UNPROTECT(5); } else if (CAR(expr) == LFactorialSymbol) { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S2(DiGammaSymbol, PP_S(PlusSymbol, 
CADR(expr), PP(ScalarInteger(1))))); UNPROTECT(4); } else if (CAR(expr) == FactorialSymbol) { ans = simplify(TimesSymbol, PP(D(CADR(expr), var)), PP_S(TimesSymbol, expr, PP_S2(DiGammaSymbol, PP_S(PlusSymbol, CADR(expr), PP(ScalarInteger(1)))))); UNPROTECT(5); } /* possible future symbols else if (CAR(expr) == Log1PExpSymbol) { ans = simplify(DivideSymbol, PP_S(TimesSymbol, PP(D(CADR(expr), var)), PP_S2(ExpSymbol, CADR(expr))), PP_S(PlusSymbol,PP(Constant(1.)), PP_S2(ExpSymbol, CADR(expr)) )); UNPROTECT(6); } else if (CAR(expr) == Log1MExpSymbol) { ans = simplify(DivideSymbol, PP_S(TimesSymbol, PP_S2(MinusSymbol, PP(D(CADR(expr), var))), PP_S2(ExpSymbol, PP_S2(MinusSymbol, CADR(expr))) ), PP_S2(ExpM1Symbol, PP_S2(MinusSymbol, CADR(expr))) ); UNPROTECT(7); } else if (CAR(expr) == Log1PMxSymbol) { ans = simplify(DivideSymbol, PP_S2(MinusSymbol, PP(D(CADR(expr), var))), PP_S(PlusSymbol,PP(Constant(1.)), CADR(expr)) ); UNPROTECT(4); } */ else { SEXP u = deparse1(CAR(expr), 0, SIMPLEDEPARSE); error(_("Function '%s' is not in the derivatives table"), translateChar(STRING_ELT(u, 0))); } break; default: ans = Constant(NA_REAL); } return ans; #undef PP_S #undef PP_S2 } /* D() */ static int isPlusForm(SEXP expr) { return TYPEOF(expr) == LANGSXP && length(expr) == 3 && CAR(expr) == PlusSymbol; } static int isMinusForm(SEXP expr) { return TYPEOF(expr) == LANGSXP && length(expr) == 3 && CAR(expr) == MinusSymbol; } static int isTimesForm(SEXP expr) { return TYPEOF(expr) == LANGSXP && length(expr) == 3 && CAR(expr) == TimesSymbol; } static int isDivideForm(SEXP expr) { return TYPEOF(expr) == LANGSXP && length(expr) == 3 && CAR(expr) == DivideSymbol; } static int isPowerForm(SEXP expr) { return (TYPEOF(expr) == LANGSXP && length(expr) == 3 && CAR(expr) == PowerSymbol); } static SEXP AddParens(SEXP expr) { SEXP e; if (TYPEOF(expr) == LANGSXP) { e = CDR(expr); while(e != R_NilValue) { SETCAR(e, AddParens(CAR(e))); e = CDR(e); } } if (isPlusForm(expr)) { if (isPlusForm(CADDR(expr))) { 
SETCADDR(expr, lang2(ParenSymbol, CADDR(expr))); } } else if (isMinusForm(expr)) { if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr))) { SETCADDR(expr, lang2(ParenSymbol, CADDR(expr))); } } else if (isTimesForm(expr)) { if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr)) || isTimesForm(CADDR(expr)) || isDivideForm(CADDR(expr))) { SETCADDR(expr, lang2(ParenSymbol, CADDR(expr))); } if (isPlusForm(CADR(expr)) || isMinusForm(CADR(expr))) { SETCADR(expr, lang2(ParenSymbol, CADR(expr))); } } else if (isDivideForm(expr)) { if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr)) || isTimesForm(CADDR(expr)) || isDivideForm(CADDR(expr))) { SETCADDR(expr, lang2(ParenSymbol, CADDR(expr))); } if (isPlusForm(CADR(expr)) || isMinusForm(CADR(expr))) { SETCADR(expr, lang2(ParenSymbol, CADR(expr))); } } else if (isPowerForm(expr)) { if (isPowerForm(CADR(expr))) { SETCADR(expr, lang2(ParenSymbol, CADR(expr))); } if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr)) || isTimesForm(CADDR(expr)) || isDivideForm(CADDR(expr))) { SETCADDR(expr, lang2(ParenSymbol, CADDR(expr))); } } return expr; } SEXP doD(SEXP args) { args = CDR(args); SEXP expr; if (isExpression(CAR(args))) expr = VECTOR_ELT(CAR(args), 0); else expr = CAR(args); if (!(isLanguage(expr) || isSymbol(expr) || isNumeric(expr) || isComplex(expr))) error(_("expression must not be type '%s'"), R_typeToChar(expr)); SEXP var = CADR(args); if (!isString(var) || length(var) < 1) error(_("variable must be a character string")); if (length(var) > 1) warning(_("only the first element is used as variable name")); var = installTrChar(STRING_ELT(var, 0)); InitDerivSymbols(); PROTECT(expr = D(expr, var)); expr = AddParens(expr); UNPROTECT(1); return expr; } /* ------ FindSubexprs ------ and ------ Accumulate ------ */ NORET static void InvalidExpression(char *where) { error(_("invalid expression in '%s'"), where); } static int equal(SEXP expr1, SEXP expr2) { if (TYPEOF(expr1) == TYPEOF(expr2)) { switch(TYPEOF(expr1)) { case 
NILSXP: return 1; case SYMSXP: return expr1 == expr2; case LGLSXP: case INTSXP: return INTEGER(expr1)[0] == INTEGER(expr2)[0]; case REALSXP: return REAL(expr1)[0] == REAL(expr2)[0]; case CPLXSXP: return COMPLEX(expr1)[0].r == COMPLEX(expr2)[0].r && COMPLEX(expr1)[0].i == COMPLEX(expr2)[0].i; case LANGSXP: case LISTSXP: return equal(CAR(expr1), CAR(expr2)) && equal(CDR(expr1), CDR(expr2)); default: InvalidExpression("equal"); } } return 0; } static int Accumulate(SEXP expr, SEXP exprlist) { SEXP e; int k; e = exprlist; k = 0; while(CDR(e) != R_NilValue) { e = CDR(e); k = k + 1; if (equal(expr, CAR(e))) return k; } SETCDR(e, CONS(expr, R_NilValue)); return k + 1; } static int Accumulate2(SEXP expr, SEXP exprlist) { SEXP e; int k; e = exprlist; k = 0; while(CDR(e) != R_NilValue) { e = CDR(e); k = k + 1; } SETCDR(e, CONS(expr, R_NilValue)); return k + 1; } static SEXP MakeVariable(int k, const char *tag) { char buf[64]; int res = snprintf(buf, 64, "%s%d", tag, k); if (res >= 64) error(_("too many variables")); return install(buf); } static int FindSubexprs(SEXP expr, SEXP exprlist, const char *tag) { SEXP e; int k; switch(TYPEOF(expr)) { case SYMSXP: case LGLSXP: case INTSXP: case REALSXP: case CPLXSXP: return 0; break; case LISTSXP: if (inherits(expr, "expression")) return FindSubexprs(CAR(expr), exprlist, tag); else { InvalidExpression("FindSubexprs"); return -1/*-Wall*/; } break; case LANGSXP: if (CAR(expr) == install("(")) { return FindSubexprs(CADR(expr), exprlist, tag); } else { e = CDR(expr); while(e != R_NilValue) { if ((k = FindSubexprs(CAR(e), exprlist, tag)) != 0) SETCAR(e, MakeVariable(k, tag)); e = CDR(e); } return Accumulate(expr, exprlist); } break; default: InvalidExpression("FindSubexprs"); return -1/*-Wall*/; } } static int CountOccurrences(SEXP sym, SEXP lst) { switch(TYPEOF(lst)) { case SYMSXP: return lst == sym; case LISTSXP: case LANGSXP: return CountOccurrences(sym, CAR(lst)) + CountOccurrences(sym, CDR(lst)); default: return 0; } } static SEXP 
Replace(SEXP sym, SEXP expr, SEXP lst)
{
    /* Destructively substitute 'expr' for every occurrence of the symbol
       'sym' in 'lst', recursing through the CAR and CDR of pairlists and
       calls.  Returns the (possibly substituted) object; objects of any
       other type are returned unchanged. */
    switch(TYPEOF(lst)) {
    case SYMSXP:
        if (lst == sym) return expr;
        else return lst;
    case LISTSXP:
    case LANGSXP:
        SETCAR(lst, Replace(sym, expr, CAR(lst)));
        SETCDR(lst, Replace(sym, expr, CDR(lst)));
        return lst;
    default:
        return lst;
    }
}

/* Build the (unprotected) call
       .grad <- array(0, c(length(.value), <n>L), list(NULL, c("n1", ...)))
   where n1, ... are the elements of the character vector 'names'. */
static SEXP CreateGrad(SEXP names)
{
    SEXP p, q, data, dim, dimnames;
    int i, n;
    n = length(names);
    PROTECT(dimnames = lang3(R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dimnames, install("list"));
    p = install("c");
    PROTECT(q = allocList(n));
    /* once stored in the protected 'dimnames', q needs no own protection */
    SETCADDR(dimnames, LCONS(p, q));
    UNPROTECT(1);
    for(i = 0 ; i < n ; i++) {
        SETCAR(q, ScalarString(STRING_ELT(names, i)));
        q = CDR(q);
    }
    PROTECT(dim = lang3(R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dim, install("c"));
    SETCADR(dim, lang2(install("length"), install(".value")));
    SETCADDR(dim, ScalarInteger(length(names))); /* was real? */
    PROTECT(data = ScalarReal(0.));
    PROTECT(p = lang4(install("array"), data, dim, dimnames));
    p = lang3(install("<-"), install(".grad"), p);
    /* releases dimnames, dim, data and the array() call */
    UNPROTECT(4);
    return p;
}

/* Build the (unprotected) call
       .hessian <- array(0, c(length(.value), <n>L, <n>L),
                         list(NULL, c("n1", ...), c("n1", ...)))
   for the names in the character vector 'names'. */
static SEXP CreateHess(SEXP names)
{
    SEXP p, q, data, dim, dimnames;
    int i, n;
    n = length(names);
    PROTECT(dimnames = lang4(R_NilValue, R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dimnames, install("list"));
    p = install("c");
    PROTECT(q = allocList(n));
    SETCADDR(dimnames, LCONS(p, q));
    UNPROTECT(1);
    for(i = 0 ; i < n ; i++) {
        SETCAR(q, ScalarString(STRING_ELT(names, i)));
        q = CDR(q);
    }
    /* third dimnames component: an independent copy of the c(...) call */
    SETCADDDR(dimnames, duplicate(CADDR(dimnames)));
    PROTECT(dim = lang4(R_NilValue, R_NilValue, R_NilValue,R_NilValue));
    SETCAR(dim, install("c"));
    SETCADR(dim, lang2(install("length"), install(".value")));
    SETCADDR(dim, ScalarInteger(length(names)));
    SETCADDDR(dim, ScalarInteger(length(names)));
    PROTECT(data = ScalarReal(0.));
    PROTECT(p = lang4(install("array"), data, dim, dimnames));
    p = lang3(install("<-"), install(".hessian"), p);
    UNPROTECT(4);
    return p;
}

/* Build the (unprotected) call  .grad[, "name"] <- expr  */
static SEXP DerivAssign(SEXP name, SEXP expr)
{
    SEXP ans, newname;
    PROTECT(ans = lang3(install("<-"), R_NilValue, expr));
    PROTECT(newname = ScalarString(name));
    SETCADR(ans, lang4(R_BracketSymbol, install(".grad"), R_MissingArg, newname));
    UNPROTECT(2);
    return ans;
}

/* Build the (unprotected) call  .hessian[, "name", "name"] <- expr  */
static SEXP HessAssign1(SEXP name, SEXP expr)
{
    SEXP ans, newname;
    PROTECT(ans = lang3(install("<-"), R_NilValue, expr));
    PROTECT(newname = ScalarString(name));
    SETCADR(ans, lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
                       newname, newname));
    UNPROTECT(2);
    return ans;
}

/* Build the (unprotected) symmetric assignment
       .hessian[, "name1", "name2"] <- .hessian[, "name2", "name1"] <- expr */
static SEXP HessAssign2(SEXP name1, SEXP name2, SEXP expr)
{
    SEXP ans, newname1, newname2, tmp1, tmp2, tmp3;
    PROTECT(newname1 = ScalarString(name1));
    PROTECT(newname2 = ScalarString(name2));
    /* this is overkill, but PR#14772 found an issue */
    PROTECT(tmp1 = lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
                         newname1, newname2));
    PROTECT(tmp2 = lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
                         newname2, newname1));
    PROTECT(tmp3 = lang3(install("<-"), tmp2, expr));
    ans = lang3(install("<-"), tmp1, tmp3);
    UNPROTECT(5);
    return ans;
}

/* attr(.value, "gradient") <- .grad */
static SEXP AddGrad(void)
{
    SEXP ans;
    PROTECT(ans = mkString("gradient"));
    PROTECT(ans = lang3(install("attr"), install(".value"), ans));
    ans = lang3(install("<-"), ans, install(".grad"));
    UNPROTECT(2);
    return ans;
}

/* attr(.value, "hessian") <- .hessian */
static SEXP AddHess(void)
{
    SEXP ans;
    PROTECT(ans = mkString("hessian"));
    PROTECT(ans = lang3(install("attr"), install(".value"), ans));
    ans = lang3(install("<-"), ans, install(".hessian"));
    UNPROTECT(2);
    return ans;
}

/* Destructively drop every cell whose CAR is R_MissingArg from the
   pairlist 'lst'; returns the head of the pruned list. */
static SEXP Prune(SEXP lst)
{
    if (lst == R_NilValue)
        return lst;
    SETCDR(lst, Prune(CDR(lst)));
    if (CAR(lst) == R_MissingArg)
        return CDR(lst);
    else return lst ;
}

SEXP deriv(SEXP args)
{
/* deriv(expr, namevec, function.arg, tag, hessian) */
    SEXP ans, ans2, expr, funarg, names, s;
    int f_index, *d_index, *d2_index;
    int i, j, k, nexpr, nderiv=0, hessian;
    SEXP exprlist, stag;
    const void *vmax = vmaxget();
    const char *tag;

    args = CDR(args);
    InitDerivSymbols();
    PROTECT(exprlist = LCONS(R_BraceSymbol, R_NilValue));

    /* expr: */
    if
(isExpression(CAR(args))) PROTECT(expr = VECTOR_ELT(CAR(args), 0)); else PROTECT(expr = CAR(args)); args = CDR(args); /* namevec: */ names = CAR(args); if (!isString(names) || (nderiv = length(names)) < 1) error(_("invalid variable names")); args = CDR(args); /* function.arg: */ funarg = CAR(args); args = CDR(args); /* tag: */ stag = CAR(args); if (!isString(stag) || length(stag) < 1 || length(STRING_ELT(stag, 0)) < 1) error(_("invalid tag")); tag = translateChar(STRING_ELT(stag, 0)); if(strlen(tag) > 60) error(_("invalid tag")); args = CDR(args); /* hessian: */ hessian = asLogical(CAR(args)); /* NOTE: FindSubexprs is destructive, hence the duplication. It can allocate, so protect the duplicate. */ PROTECT(ans = duplicate(expr)); f_index = FindSubexprs(ans, exprlist, tag); d_index = (int*)R_alloc((size_t) nderiv, sizeof(int)); if (hessian) d2_index = (int*)R_alloc((size_t) ((nderiv * (1 + nderiv))/2), sizeof(int)); else d2_index = d_index;/*-Wall*/ UNPROTECT(1); for(i=0, k=0; i