I’ve already written a little about automatic differentiation with Jax, using it to compute the Hessian and to run Conway’s Game of Life on accelerators (GPUs).

While computing the Jacobian and Hessian matrices is often an intermediate step in optimisation algorithms, they are occasionally useful in their own right. One such example is the Black-Scholes model, where the Jacobian of the present value of a derivative (the “Greeks”) is used to hedge against market movements. Here is a little illustration of how easily the Greeks can be computed using Jax.

I’m also using Sam Schoenholz’s jax2tex (code at https://github.com/google-research/google-research/tree/master/jax2tex ) to display the results in symbolic form.

The Black-Scholes analytic model

The starting point is the analytic Black-Scholes model (Black & Scholes, 1973). I’ve expressed it in terms of the Black 1976 formula (Black, 1976).

import jax
import jax.scipy
from jax.scipy import special as Sfn
from jax import numpy

# Standard normal cumulative distribution function
def Phi(z):
    return (1 + Sfn.erf(z / numpy.sqrt(2.))) / 2

# Black (1976) formula for an option on a forward F
def Black76(cp, F, K, r, sigma, T):
    d1 = (numpy.log(F / K) + (sigma**2) / 2 * T) / (sigma * numpy.sqrt(T))
    d2 = d1 - sigma * numpy.sqrt(T)
    if cp == "C":
        return numpy.exp(-r * T) * (F * Phi(d1) - K * Phi(d2))
    else:
        return numpy.exp(-r * T) * (K * Phi(-d2) - F * Phi(-d1))

# Black-Scholes (1973) expressed via Black 76 with forward F = exp(r * T) * S
def BlackScholes(cp, S, K, r, sigma, T):
    F = numpy.exp(r * T) * S
    return Black76(cp, F, K, r, sigma, T)

Note that the call/put switch is the first parameter, so that the call and put versions are easily generated by partial application:

from functools import partial

def mkCallPut(f):
    c = partial(f, "C")
    c.__name__ = f.__name__ + "Call"
    p = partial(f, "P")
    p.__name__ = f.__name__ + "Put"
    return c, p

BlackScholesCall, BlackScholesPut = mkCallPut(BlackScholes)

The model can now be directly evaluated:

BlackScholesCall(100, 100, 0.01, 0.05, 1)

returns a value of around 2.52.
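As a quick sanity check (a small sketch of my own, reusing the definitions above), the difference between the call and the put value should satisfy put-call parity, i.e. equal S - K*exp(-r*T):

# Put-call parity check (my own addition, not part of the original model code)
call = BlackScholesCall(100., 100., 0.01, 0.05, 1.)
put = BlackScholesPut(100., 100., 0.01, 0.05, 1.)
print(call - put)                      # ~0.995
print(100. - 100. * numpy.exp(-0.01))  # ~0.995, i.e. S - K*exp(-r*T)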

With a bit of extension of jax2tex we can also print the symbolic version of the model:

from jax2tex.jax2tex import op2tex, op2ind, noop2ind
from jax2tex import jax2tex

op2tex[jax.lax.erf_p] = lambda x: f"\\mathrm{{ Erf }}\\left[{x}\\right]"
op2ind[jax.lax.erf_p] = noop2ind

jax2tex(BlackScholesCall, 100, 100, 0.01, 0.05, 1)

We get the following:

\[\text{BlackScholesCall} = e^{\left(-r\right)T}\left(e^{rT}S{\mathrm{ Erf }\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T} } \over \sqrt{2.0} }\right] + 1.0 \over 2.0} - K{\mathrm{ Erf }\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T} } - sigma\sqrt{T} \over \sqrt{2.0} }\right] + 1.0 \over 2.0}\right)\]

Note that the output will contain &= which needs care when rendering in HTML!
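One simple way to handle it (a sketch of my own) is a string replacement on the returned LaTeX before embedding it in the page:

# jax2tex returns a LaTeX string; drop the alignment marker for plain HTML/MathJax use
tex = jax2tex(BlackScholesCall, 100, 100, 0.01, 0.05, 1)
tex = tex.replace("&=", "=")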

The Greeks

With Jax’s automatic differentiation, using the grad function, we can obtain the Greeks trivially:

from jax import grad

Delta, Rho, Vega, mTheta = [grad(BlackScholesCall, argnums=x) for x in [0, 2, 3, 4]]

(NB: this is the negative of the conventional Theta, hence the “m” prefix.) For example, we can evaluate Delta simply as:

Delta(100., 100., 0.01, 0.05, 1.)

Very easy, and with much less scope for user-introduced errors!
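A couple of variations, as sketches of my own using the same definitions: grad also accepts a tuple of argnums, giving all the first-order Greeks in one call, and Gamma falls out as a nested grad:

# All first-order Greeks in a single call (argnums given as a tuple)
AllGreeks = grad(BlackScholesCall, argnums=(0, 2, 3, 4))
delta, rho, vega, mtheta = AllGreeks(100., 100., 0.01, 0.05, 1.)

# Gamma: second derivative of the price with respect to the spot S
Gamma = grad(grad(BlackScholesCall, argnums=0), argnums=0)
Gamma(100., 100., 0.01, 0.05, 1.)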

It is also possible to display the symbolic form of the Greeks. It should be noted that these are computed using reverse-mode differentiation and will not necessarily have the simplest analytic form, but they can still be useful for spotting potential issues:
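The symbolic forms below can be obtained by calling jax2tex on the grad functions in the same way as before (the exact calls here are a sketch, along the lines of the earlier one):

jax2tex(Delta, 100., 100., 0.01, 0.05, 1.)
jax2tex(Rho, 100., 100., 0.01, 0.05, 1.)
jax2tex(Vega, 100., 100., 0.01, 0.05, 1.)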

\[\text{Delta} = { { { { {-1.0e^{-rT}K \over 2.0}1.128379225730896e^{-{\left({ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T} } - sigma\sqrt{T} \over 1.4142135381698608}\right)}^{2} } \over 1.4142135381698608} + { {1.0e^{-rT}e^{rT}S \over 2.0}1.128379225730896e^{-{\left({ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T} } \over 1.4142135381698608}\right)}^{2} } \over 1.4142135381698608} \over sigma\sqrt{T} } \over {e^{rT}S \over K} } \over K} + 1.0e^{-rT}{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T} } \over 1.4142135381698608}\right] + 1.0 \over 2.0}e^{rT}\] \[\text{Rho} = { { { { {-1.0e^{-rT}K \over 2.0}1.128379225730896e^{-{\left({ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} - sigma\sqrt{T} \over 1.4142135381698608}\right)}^{2}} \over 1.4142135381698608} + { {1.0e^{-rT}e^{rT}S \over 2.0}1.128379225730896e^{-{\left({ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} \over 1.4142135381698608}\right)}^{2}} \over 1.4142135381698608} \over sigma\sqrt{T}} \over {e^{rT}S \over K}} \over K} + 1.0e^{-rT}{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} \over 1.4142135381698608}\right] + 1.0 \over 2.0}Se^{rT}T + -1.0\left(e^{rT}S{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} \over 1.4142135381698608}\right] + 1.0 \over 2.0} - K{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} - sigma\sqrt{T} \over 1.4142135381698608}\right] + 1.0 \over 2.0}\right)e^{-rT}T\] \[\text{Vega} = { { { { {-1.0e^{-rT}K \over 2.0}1.128379225730896e^{-{\left({ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} - sigma\sqrt{T} \over 1.4142135381698608}\right)}^{2}} \over 1.4142135381698608} + { {1.0e^{-rT}e^{rT}S \over 2.0}1.128379225730896e^{-{\left({ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} \over 1.4142135381698608}\right)}^{2}} \over 1.4142135381698608} \over sigma\sqrt{T}} \over {e^{rT}S \over K}} \over K} + 1.0e^{-rT}{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} \over 1.4142135381698608}\right] + 1.0 \over 2.0}Se^{rT}T + -1.0\left(e^{rT}S{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} \over 1.4142135381698608}\right] + 1.0 \over 2.0} - K{\mathrm{Erf}\left[{ {\log\left({e^{rT}S \over K}\right) + { {sigma}^{2} \over 2.0}T \over sigma\sqrt{T}} - sigma\sqrt{T} \over 1.4142135381698608}\right] + 1.0 \over 2.0}\right)e^{-rT}T\]

Summary

Automatic differentiation greatly reduces the coding (and analysis) effort when derivatives are needed. This is true both when they are used for optimisation and when they are needed in their own right. Jax makes automatic differentiation as easy as a single function call on a Numpy-like function.

The other major advantage of Jax is acceleration – more on that in a forthcoming post!

Originally published 7th September 2020, updated 18th October 2021 (update for changes in jax2tex)

Bibliography

  1. Black, F., & Scholes, M. (1973). The pricing of options and corporate liabilities. Journal of Political Economy, 81(3), 637–654.
  2. Black, F. (1976). The pricing of commodity contracts. Journal of Financial Economics, 3(1-2), 167–179.