16 de junio de 2015

Optimizando “Send More Money” en Racket

El problema

Estuve leyendo algunos artículos que discutían el problema de encontrar por fuerza bruta dígitos distintos para que valga:

  SEND
+ MORE
------
 MONEY
Vamos a ver cómo resolverlo en Racket, siguiendo la idea de sólo usar fuerza bruta pero ignorando completamente los monads. La idea es ir haciendo pequeños cambios en el programa y ver cómo afectan el tiempo de ejecución.

Usando astutamente unas macros, podemos hacer que el programa vaya unas x12 o x14 veces más rápido "casi" sin modificar el programa (para una definición laxa de "casi".)

(En este artículo, voy a remarcar las macros que agrego poniéndoles m:algo en el nombre para destacarlas. En particular, porque reemplazan a funciones usuales de Racket. En general, las macros no se destacan y simplemente se mezclan con las funciones.)

En los ejemplos voy a comparar los tiempos de ejecución en mi computadora de escritorio y en mi netbook. En cada caso, se muestra el promedio de 5 ejecuciones del programa.

Traducción literal de Haskel

La primera versión es casi una copia textual de la versión en Haskel, cambiando el monad list por un for. (Es muy parecido a la versión propuesta por minikomi en HN. En mi primera versión, yo no recordaba que existía la función remove* de Racket y la había vuelto a programar.)

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPU GC CPU GC
literal 2648,8 58,8 11051,0 174,6

Sacamos como conclusión que mi netbook es una 4 veces más lenta que mi computadora de escritorio :). Podríamos compararlo con los tiempos de ejecución en otros lenguajes, pero es claro que influye mucho la velocidad de la computadora.


#lang racket
(define digits '(0 1 2 3 4 5 6 7 8 9))

(define (to-number digits)
  (foldl {lambda (x y) (+ (* 10 y) x)} 0 digits))

(time (for*/list ([s (remove* (list 0) digits)]
                  [e (remove* (list s) digits)]
                  [n (remove* (list s e) digits)]
                  [d (remove* (list s e n) digits)]
                  [send (in-value (to-number (list s e n d)))]
                  [m (remove* (list 0 s e n d) digits)]
                  [o (remove* (list s e n d m) digits)]
                  [r (remove* (list s e n d m o) digits)]
                  [more (in-value (to-number (list m o r e)))]
                  [y (remove* (list s e n d m o r) digits)]
                  [money (in-value (to-number (list m o n e y)))]
                  #:when (= (+ send more) money))
        (list send more money)))

Usando in-list para indicar que son listas

A diferencia de Haskel, en Racket las variables no tienen un tipo fijo y no hay inferencia de tipos. (Me gustaría ver cómo anda el programa anterior en Typed Racket.)

Así que es conveniente usar in-list para darle una "pista" a los for de que el resultado de remove* es una lista. En realidad in-list no es una "pista", sinó que tiene las instrucciones para recorrer una lista eficientemente. Sin esta "pista", el for usa instrucciones genéricas que son un poquito más lentas.

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPU GC CPU GC
literal 2648,8 58,8 11051,0 174,6
in-list 2377,4 15,4 9185,4 62,6

El tiempo de ejecución bajó un 10%-20%. El tiempo de GC bajó un 60%, probablemente porque ya no se necesita crear las estructuras intermedias de las secuencias genéricas que se usan para los ciclos.

Notar que en valor absoluto, el tiempo de CPU bajó mucho más de lo que bajó el tiempo de GC. Una parte muy chica del tiempo se pierde en el GC.


(time (for*/list ([s (in-list (remove* (list 0) digits))]
                  [e (in-list (remove* (list s) digits))]
                  [n (in-list (remove* (list s e) digits))]
                  [d (in-list (remove* (list s e n) digits))]
                  [send (in-value (to-number (list s e n d)))]
                  [m (in-list (remove* (list 0 s e n d) digits))]
                  [o (in-list (remove* (list s e n d m) digits))]
                  [r (in-list (remove* (list s e n d m o) digits))]
                  [more (in-value (to-number (list m o r e)))]
                  [y (in-list (remove* (list s e n d m o r) digits))]
                  [money (in-value (to-number (list m o n e y)))]
                  #:when (= (+ send more) money))
        (list send more money)))

Cambiando equal? por =

El codigo anterior usa muchos equal?s. Lo peor es que son equal?s invisibles. Resulta que (remove* list proc) es equivalente a (remove list proc equal?). Entonces los cambiamos por (remove list proc =). El problema es que equal? puede llamar código arbitrario, por ejemplo para ver que dos structs son equal?. Por esto el código con equal? es más lento y difícil de optimizar. En cambio = llama siempre a una comparación simple.

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPU GC CPU GC
in-list 2377,4 15,4 9185,4 62,6
= 1507,0 28,2 5559,8 62,8

El tiempo de ejecución bajó un 40%. El tiempo de GC subió un poco. Resulta que usar equal? es mucho más lento que =.


(time (for*/list ([s (in-list (remove* (list 0) digits =))]
                  [e (in-list (remove* (list s) digits =))]
                  [n (in-list (remove* (list s e) digits =))]
                  [d (in-list (remove* (list s e n) digits =))]
                  [send (in-value (to-number (list s e n d)))]
                  [m (in-list (remove* (list 0 s e n d) digits =))]
                  [o (in-list (remove* (list s e n d m) digits =))]
                  [r (in-list (remove* (list s e n d m o) digits =))]
                  [more (in-value (to-number (list m o r e)))]
                  [y (in-list (remove* (list s e n d m o r) digits =))]
                  [money (in-value (to-number (list m o n e y)))]
                  #:when (= (+ send more) money))
        (list send more money)))

Cambiando remove* por #:when

El problema es que los remove* crean en cada pasada una nueva lista sin los números que no queremos que se repitan. Esto crea muchas muchas muchas listas. El GC casi no pierde tiempo porque para cuando se ejecutan casi todas estas listas están sin referencias y simplemente desaparecen. Sin embargo, las listas se van creando en distintas páginas de la memoria y eso hace que sea lento ir a buscarlas.

Entonces reemplazamos los remove* por #:when. En vez de sacar los elementos de la lista antes de elegir el número, revisamos que no estén repetidos después de elegirlo.

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPU GC CPU GC
= 1507,0 28,2 5559,8 62,8
#:when 1766,0 40,8 6071,8 187,4

El tiempo de ejecución subió un 10%-20% :(. El tiempo de GC subió x2-x3 :(2.

Supongo que porque hay más listas intermedias. Todavía no entiendo los detalles de esto. Aunque el tiempo empeoró, todas las funciones que aparecen son más simples, así que va a ser más fácil aplicar la próxima trasformación.


(time (for*/list ([s (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x s)} (list 0)))
                  [e (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x e)} (list s)))
                  [n (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x n)} (list s e)))
                  [d (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x d)} (list s e n)))
                  [send (in-value (to-number (list s e n d)))]
                  [m (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x m)} (list 0 s e n d)))
                  [o (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x o)} (list s e n d m)))
                  [r (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x r)} (list s e n d m o)))
                  [more (in-value (to-number (list m o r e)))]
                  [y (in-list digits)]
                  #:when (not (ormap {lambda (x) (= x y)} (list s e n d m o r)))
                  [money (in-value (to-number (list m o n e y)))]
                  #:when (= (+ send more) money))
        (list send more money)))

Usando macros en vez de funciones de orden superior

En la versión anterior hay muchas closures {lambda (x) (= x y)} que parecen crearse en memoria. Sin embargo Racket tiene muchos trucos para transformarlas internamente en funciones normales o directamente eliminarlas.

En este caso, el compilador las inlinea. En realidad en este caso primero inlinea ormap y después inlinea las {lambda ...}, así que en el código que se ejecuta no están estas closures y ni siquiera aparecen como llamados a funciones porque el código ya está inlineado.

Los que están causando problemas son ese montón de (list ...) que crean muchas listas que se destruyen casi inmediatamente. Entonces podemos reemplazar ormap (que es una función) por una nueva macro m:ormap.


(define-syntax m:ormap
  (syntax-rules (list)
    [(m:ormap proc (list n ...)) (let ([proc-val proc])
                                   (or (proc-val n) ...))]))

Como m:ormap es una macro puede hacer cosas más raras, como espiar las expresiones en las posiciones de los parámetros y ver que el segundo es una lista en forma explicita de la forma (list ...) y aplicar la función proc directamente a cada uno de los elementos sin crear la lista en ningún momento.

Acá hay un detalle técnico. Necesitamos crear una variable intermedia proc-val  para que contenga el valor de proc. Si usamos directamente proc, la expresión de proc se repetiría y entonces el código que genera proc se ejecutaría varias veces. Eso podría causar resultados inesperados. Por ejemplo, comparemos los resultados cuándo usamos la versión errónea:


(define-syntax m:ormap/bad
  (syntax-rules (list)
    [(m:ormap proc (list n ...)) (or (proc n) ...)]))

(ormap (begin (display "*") {lambda (x) (display x) #f}) (list 1 2 3))
(m:ormap (begin (display "*") {lambda (x) (display x) #f}) (list 1 2 3))
(m:ormap/bad (begin (display "*") {lambda (x) (display x) #f}) (list 1 2 3))
(newline)

(ormap (let ([r (random 6)]) {lambda (x) (display r) #f}) (list 1 2 3))
(m:ormap (let ([r (random 6)]) {lambda (x) (display r) #f}) (list 1 2 3))
(m:ormap/bad (let ([r (random 6)]) {lambda (x) (display r) #f}) (list 1 2 3))
(newline)

En general es bueno que las macro sean lo más intuitivas posible porque en Racket no se acostumbra indicar en el nombre que son macros, así que hay que tratar que no tengan comportamiento sorpresivo. En este caso, lo ideal sería que m:ormap se comporte lo más parecido posible a ormap. Que sea casi como una función, sólo con la magia de espiar adentro de las listas. Por eso es mejor que evalúe el primer parámetro sólo una vez.

Otro detalle importante es que si la función es sencilla el compilador de Racket la puede inlinear. En este caso las {lambda (x) (= x n)} se inlinean. Y como siempre está inlineada la variable auxiliar p_ desaparece durante la compilación. Entonces
(ormap {lambda (x) (= x n)} (list s e))
Queda equivalente a
(or (= s n) (= e n))

De la misma manera, cambiamos foldl por m:foldl. La definición es un poquito más complicada. Además en vez de una variable intermedia crea un montón de variables intermedias que por suerte desaparecen (de forma similar a las proc-val que había en m:ormap).





(define-syntax m:foldl
  (syntax-rules (list)
    [(m:foldl proc ini (list)) ini]
    [(m:foldl proc ini (list x y ...)) (let ([proc-val proc])
                                         (m:foldl proc-val (proc x ini) (list y ...)))]))


Acá hay otro problema técnico si usamos m:fold. (que es una macro) en to-number 
(que es una función). Resulta que to-number se inlinea después de la expansión 
de m:fold. Así que m:fold no ve la lista explicita sino que ve el argumento de to-number. En general hay que tener cuidado porque las macros que parecen 
funciones a veces no se componen mágicamente con las funciones de verdad. Por 
eso tenemos que armar también la macro m:to-number. 
(define-syntax-rule (m:to-number digits)
  (m:foldl {lambda (x y) (+ (* 10 y) x)} 0 digits))

Ahora primero se expande m:to-number
(m:to-number (list s e n d))
en
(m:foldl {lambda (x y) (+ (* 10 y) x)} 0 (list s e n d))
sin romper la lista explicita y entonces m:fold puede ver la lista explícita cuando se expande.

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPUGC CPUGC
#:when1766,040,8 6071,8187,4
macros271,40,0 951,40,0

¡El tiempo de ejecución bajó un 85%, o sea un x6! El tiempo de GC bajó a 0. El tiempo de GC era poco, pero es un buen síntoma de que no se crean listas o closures intermedias.

La verdad es que este uso de las macros es medio lío, sobre todo porque hay que pensar bien como se componen con las funciones. Pero un x6 vale la pena en algunos for muy usados y benchmarks. Para la mayor parte del código me parece que es mejor usar funciones, y quedarse tranquilo de que manejan correctamente los casos difíciles.

Sería bueno que esta optimización se aplicara automáticamente. A ojo, la optimización de
(ormap display (list 1 2 3))
a
(ormap (display 1) (display 2) (display 3))
no parece muy difícil. Pero cuando el compilador de Racket puede actuar ormap ya está inlineada y la estructura que ve el optimizador es más complicada. Es una de las optimizaciones que tengo en mi lista de deseos y creo que en algún momento se puede agregar a Racket.


(define-syntax m:ormap
  (syntax-rules (list)
    [(m:ormap proc (list n ...)) (let ([proc-val proc])
                                   (or (proc-val n) ...))]))

(define-syntax m:foldl
  (syntax-rules (list)
    [(m:foldl proc ini (list)) ini]
    [(m:foldl proc ini (list x y ...)) (let ([proc-val proc])
                                         (m:foldl proc-val (proc x ini) (list y ...)))]))

(define-syntax-rule (m:to-number digits)
  (m:foldl {lambda (x y) (+ (* 10 y) x)} 0 digits))

(time (for*/list ([s (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x s)} (list 0)))
                  [e (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x e)} (list s)))
                  [n (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x n)} (list s e)))
                  [d (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x d)} (list s e n)))
                  [send (in-value (m:to-number (list s e n d)))]
                  [m (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x m)} (list 0 s e n d)))
                  [o (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x o)} (list s e n d m)))
                  [r (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x r)} (list s e n d m o)))
                  [more (in-value (m:to-number (list m o r e)))]
                  [y (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (= x y)} (list s e n d m o r)))
                  [money (in-value (m:to-number (list m o n e y)))]
                  #:when (= (+ send more) money))
        (list send more money)))

Usando funciones unsafe

Esta es una transformación sencilla y quizás se podría haber aplicado antes. El problema es que transforma código que genera errores en código erroneo o que cuelga el programa. Así que mejor reservarla para casos extremos. La idea es reemplazar las operaciones de enteros por su versión unsafe. Notar que in-list chequea primero que el argumento realmente sea una lista, y después usa internamente las versiones unsafe, así que no me parece necesario crear un unsafe-in-list. (Además es complicado crear este tipo de generadores de secuencias. Da para otro artículo completo.)

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPUGC CPUGC
macros271,40,0 951,40,0
unsafe234,00,0 833,20,0
El tiempo de ejecución bajó sólo un 15%. No está mal mientras no nos equivoquemos y terminemos colgando el programa.


(require racket/unsafe/ops)

(define-syntax-rule (m:unsafe-fx-to-number digits)
  (m:foldl {lambda (x y) (unsafe-fx+ (unsafe-fx* 10 y) x)} 0 digits))

(time (for*/list ([s (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x s)} (list 0)))
                  [e (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x e)} (list s)))
                  [n (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x n)} (list s e)))
                  [d (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x d)} (list s e n)))
                  [send (in-value (m:unsafe-fx-to-number (list s e n d)))]
                  [m (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x m)} (list 0 s e n d)))
                  [o (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x o)} (list s e n d m)))
                  [r (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x r)} (list s e n d m o)))
                  [more (in-value (m:unsafe-fx-to-number (list m o r e)))]
                  [y (in-list digits)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x y)} (list s e n d m o r)))
                  [money (in-value (m:unsafe-fx-to-number (list m o n e y)))]
                  #:when (unsafe-fx= (unsafe-fx+ send more) money))
        (list send more money)))

Usando in-range en vez de in-list

La idea es que en vez de usar (in-list digits) podemos usar (in-range 10). En el for, la cláusula (in-range 10) no crea una secuencia sino que tiene las instrucciones para recorrer los 10 números. Así que nuevamente no se crea una lista o secuencia innecesaria. Esto parece más rápido ya que es sólo cuestión de sumar números que están en el micro o en el cache, en vez de buscar los elementos de la lista en cualquier posición de memoria.

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPUGC CPUGC
unsafe234,00,0 833,20,0
in-range231,00,0 764,20,0

El tiempo de ejecución bajó sólo un 1%-10%, al borde del error de medición en la computadora de escritorio. Es menos de lo que yo esperaba.


(time (for*/list ([s (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x s)} (list 0)))
                  [e (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x e)} (list s)))
                  [n (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x n)} (list s e)))
                  [d (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x d)} (list s e n)))
                  [send (in-value (m:unsafe-fx-to-number (list s e n d)))]
                  [m (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x m)} (list 0 s e n d)))
                  [o (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x o)} (list s e n d m)))
                  [r (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x r)} (list s e n d m o)))
                  [more (in-value (m:unsafe-fx-to-number (list m o r e)))]
                  [y (in-range 10)]
                  #:when (not (m:ormap {lambda (x) (unsafe-fx= x y)} (list s e n d m o r)))
                  [money (in-value (m:unsafe-fx-to-number (list m o n e y)))]
                  #:when (unsafe-fx= (unsafe-fx+ send more) money))
        (list send more money)))

Conclusiones

Comparemos de vuelta todos los tiempos de ejecución:

Tiempo de ejecución promedio (5 corridas, en milisegundos):

CPUGC CPUGC
literal2648,858,8 11051,0174,6
in-list2377,415,4 9185,462,6
=1507,028,2 5559,862,8
#:when1766,040,8 6071,8187,4
macros271,40,0 951,40,0
unsafe234,00,0 833,20,0
in-range231,00,0 764,20,0


Para que sean más fáciles de comparar hagamos dos gráficos. En un gráfico están los tiempos reales. En el otro los tiempos amuchados para que todos sean visibles. Dividimos a los tiempos de la netbook por 4 y multiplicamos los tiempos del GC por 10 (O sea que el tiempo del GC en la netbook está multiplicado por 10/4.) Es notable que se pierde poco tiempo con el GC.

Tiempos reales (misma escala)

Tiempos amuchados (diferentes escalas)


En cada paso hicimos transformaciones pequeñas, que nos permitieron llegar a una versión 12 o 14 veces más rápida del programa. La única parte complicada fue armar las macros m:ormap y m:fold que son reutilizables (y hay que recordar armar también m:to-number). Con estas macros podemos hacer todo el tanteo sin utilizar listas y así lograr que el programa vaya más rápido.

El problema es que el nuevo programa es un poco más difícil de mantener, como se hizo evidente por tener que armar la macro m:to-number. Así que mejor reservar las macros especiales para ciclos muy usados y benchmarks. (O para los casos en que realmente se necesita una macro, por supuesto.)

Para extender m:ormap

Para poder usar m:ormap como un reemplazo de ormap en cualquier momento sin preocuparse, habría que resolver algunos detalles:
  • Estaría bueno agregarle chequeo de errores a la macro m:ormap. Cuando uno empieza a componer macros, si cada una de ellas no tienen buenos chequeos de errores, los errores se detectan muy tarde. En ese momento el código ya sufrió varias transformaciones y los mensajes de error son muy crípticos.
  • Mejor todavía, sería bueno que m:ormap se transforme en ormap si no encuentra una lista explicita, así la podemos aplicar siempre.
  • Además estaría bueno agregarle que detecte los casos '(), '(1 2 3), null, (cons x (list ...)), y si uno es valiente quizás (list* ...) y (append ...)
  • Hay que tener cuidado con casos extraños con listas cuyos elementos tienen efectos secundarios, como
    (m:ormap {lambda (x) (display 1)} (list (display 2) (error 'error)) (values 3 4) (display 5))
    que antes de generar un error debería mostrar solamente
    2 en vez de 21 o 251 o ...
    Creo que se puede arreglar, pero habría que alargar mucho la macro.
  • Ya que estamos, se puede agregar algunos detalles como usar syntax-property de 'disappeared-use para que DrRacket le ponga la flechita a list.