root/exactrbn/trunk/src/rbn_solve.h

Revision 670, 16.0 kB (checked in by tapted, 2 years ago)

Remove confusing debug vector goo

  • Property svn:eol-style set to native
  • Property svn:keywords set to author date id revision url Rev Revision
Line 
1/* $Id$ $URL$ */
2#ifndef RBN_SOLVE_DOT_AITCH
3#define RBN_SOLVE_DOT_AITCH
4
5/**\file rbn_solve.h
6 * Template classes MyVec, Example and ExactRBFNet for implementing
7 * and solving exact-interpolation radial basis function neural networks.
8 * Copyright Trent Apted 2008 All rights reserved.
9 *
10 * \author Trent Apted <tapted@it.usyd.edu.au>
11 * \version $Revision$
12 * \date $Date$
13 *
14 * Macros to influence behaviour (if defined before # include):
15 * - # define RBFNET_READ_ONLY // make a "read only" network that must be read from disk
16 *                              (removes dependencies on lapack)
17 * - # define RBFNET_DEBUG     // prints debugging information to stderr
18 * - # define RBFNET_VDEBUG    // prints more debugging information to stderr
19 * \n\n
20 */
21
22#include <vector>
23namespace stl = std;
24
25#include <unistd.h>
26#include <valarray>
27#include <istream>
28#include <ostream>
29#include <fstream>
30#include <typeinfo>
31
32/**
33 * Simple N-dimensional vector of type T.
34 * \param T the type of elements in the vector
35 * \param DIM the dimension (size) of the vector
36 */
37template <class T, int DIM>
38    struct MyVec {
39        T v[DIM]; ///< The components
40        /** Indexation */
41        T& operator[](size_t i) {return v[i];}
42        /** Const indexation */
43        const T& operator[](size_t i) const {return v[i];}
44    };
45
46/**
47 * A "training" example -- an instance of a function mapping a value
48 * from the domain of T^INDIM to the range T^OUTDIM.
49 * \param T the type of element (e.g. float, double, complex)
50 * \param INDIM the dimension of the domain
51 * \param OUTDIM the dimension of the range
52 */
53template <class T, int INDIM, int OUTDIM>
54struct Example {
55    MyVec<T, INDIM> in;   ///< The input vector
56    MyVec<T, OUTDIM> out; ///< The resulting output vector
57};
58
59/** Gaussian function */
60static inline double gauss(const double sigma, const double distsq)
61;//__attribute__((const, fastcall, nothrow, unused));
62
63/** gaussian distance for floats */
64static inline float gauss(const float sigma, const float distsq)
65;//__attribute__((const, fastcall, nothrow, unused));
66
67/** Gaussian "distance" between two points -- float version */
68template <class T, int DIM>
69static inline T gdist(const T *const v1, const T *const v2, const T sigma)
70;//__attribute__((regparm(3), nonnull, nothrow, pure, unused));
71
72
73/** Gaussian "distance" between two points -- double version */
74template <class T, int DIM>
75static inline T dgdist(const T *const v1, const T *const v2, const T sigma)
76;//__attribute__((regparm(3), nonnull, nothrow, pure, unused));
77
78/** The euler distance between two poitns */
79template <class T, int DIM>
80static inline T distsq(const T* const v1, const T *const v2) {
81    T sum = 0;
82    for (unsigned i = 0; i < DIM; ++i) {
83        const T d = v1[i] - v2[i];
84        sum += d*d;
85    }
86    return sum;
87}
88
89/** Gaussian distribution functor */
90template <class T, int DIM>
91struct GDist {
92    /** Calls gdist() */
93    T operator() (const T *const v1, const T *const v2, const T sigma) const {
94        return gdist<T, DIM>(v1, v2, sigma);
95    }
96};
97
98/** Euler/linear distribution functor */
99template <class T, int DIM>
100struct EulerDistSq {
101    /** Calls sqrt() of distsq() */
102    T operator() (const T *const v1, const T *const v2, const T /*sigma*/) const {
103        return sqrt(distsq<T, DIM>(v1, v2));
104    }
105};
106
107
108/**
109 * \brief An Exact-interpolation radial basis function (RBF) neural network (NN).
110 *
111 * The interpolation task:
112 * <ul>
113 *   <li>Given</li>
114 *   <ol>
115 *     <li>A black box with \f$n\f$ inputs \f$x_1, x_2, \dots, x_n\f$ and
116 *        an \f$m\f$ outputs \f$y_1, y_2, \dots, y_m\f$</li>
117 *     <li>A database of \f$k\f$ sample input-output \f$(\boldsymbol{x}, \boldsymbol{y}(\boldsymbol{x}))\f$
118 *        combinations (training examples)</li>
119 *   </ol>
120 *   <li>Goal: Create a NN that predicts the output given an input vector \f$\boldsymbol{x}\f$</li>
121 *   <ul>
122 *     <li>if \f$\boldsymbol{x}\f$ is a vector from the sample data, it must
123 *       produce the exact output of the black box</li>
124 *     <li>otherwise it must produce output clost to the output of the black
125 *        box (i.e. a result from the [most likely non-linear] interpolation of
126 *        the training examples.)</li>
127 *   </ul>
128 * </ul>
129 *
130 * \author Trent Apted <tapted@it.usyd.edu.au>
131 *
132 * \param T The type of the floating point number to use. Currently \a double and
133 *          \a float are supported. \a complex is also a possibility.
134 * \param INDIM The dimension of the domain (inputs)
135 * \param OUTDIM The dimension of the range (outputs)
136 */
137template <class T, int INDIM, int OUTDIM, class DISTFUNC = GDist<T, INDIM> >
138class ExactRBFNet {
139public:
140    typedef MyVec<T, INDIM> INVEC;   ///< An input vector -- \f$\boldsymbol{x}\f$
141    typedef MyVec<T, OUTDIM> OUTVEC; ///< An output vector -- \f$\boldsymbol{y}(\boldsymbol{x})\f$
142    typedef Example<T, INDIM, OUTDIM> EXAMPLE; ///< An example \f$(\boldsymbol{x}, \boldsymbol{y}(\boldsymbol{x}))\f$
143    typedef stl::vector<EXAMPLE> EXAMPLEVEC;   ///< A collection of examples/centres
144#ifndef RBFNET_DEBUG
145protected:
146#endif
147    EXAMPLEVEC examples;              ///< The examples given AND the "centres" of the gaussians
148    std::valarray<T> weights[OUTDIM]; ///< The weights vector -- \f$\boldsymbol{w}\f$
149    T sigma;                          ///< The \f$\sigma\f$ value to use for the gaussian function
150    DISTFUNC distfunc;                ///< The distance function
151    bool badbit;                      ///< Set if/when there are any read errors
152public:
153    /** Construct a (previously output) ExactRBFNet from an input stream. */
154    ExactRBFNet(std::istream& in, const DISTFUNC &df = DISTFUNC()) : sigma(0), distfunc(df), badbit(false) { input(in); }
155
156    /** Construct a (previously output) ExactRBFNet from a file. */
157    ExactRBFNet(const char* file, const DISTFUNC &df = DISTFUNC()) : sigma(0), distfunc(df), badbit(false) { std::ifstream f(file); input(f); }
158
159#ifndef RBFNET_READ_ONLY
160    /**
161     * Construct a new ExactRBFNet with a given value for \f$\sigma\f$. \f$\sigma\f$ is the
162     * width / influence of the Gaussian function. A larger \f$\sigma\f$ means that each
163     * example will have a greater influence over the domain -- each sub() will be influenced
164     * by a greater number of training examples.
165     *
166     * \param sig \a sigma, \f$\sigma\f$
167     * \param df The distance function to use
168     */
169    ExactRBFNet(const T& sig = 4, const DISTFUNC &df = DISTFUNC()) : sigma(sig), distfunc(df), badbit(false) {}
170
171    /**
172     * Add a training example / centre.
173     */
174    void addExample(const EXAMPLE &ex);
175
176    /**
177     * \brief Solve this RBF Network for the weights.
178     *
179     * That is, solve the set of linear equations:
180     *
181     * \f{eqnarray*}
182     *   y(\boldsymbol{x}_1) & = & w_1\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_1-\boldsymbol{x}_1\|^2}{\scriptstyle 2\sigma} +
183     *                             w_2\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_1-\boldsymbol{x}_2\|^2}{\scriptstyle 2\sigma} + \cdots +
184     *                             w_k\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_1-\boldsymbol{x}_k\|^2}{\scriptstyle 2\sigma} \\
185     *   y(\boldsymbol{x}_2) & = & w_1\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_2-\boldsymbol{x}_1\|^2}{\scriptstyle 2\sigma} +
186     *                             w_2\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_2-\boldsymbol{x}_2\|^2}{\scriptstyle 2\sigma} + \cdots +
187     *                             w_k\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_2-\boldsymbol{x}_k\|^2}{\scriptstyle 2\sigma} \\
188     *    & \vdots & \\
189     *   y(\boldsymbol{x}_k) & = & w_1\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_k-\boldsymbol{x}_1\|^2}{\scriptstyle 2\sigma} +
190     *                             w_2\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_k-\boldsymbol{x}_2\|^2}{\scriptstyle 2\sigma} + \cdots +
191     *                             w_k\mathrm{e}^\frac{\scriptstyle \|\boldsymbol{x}_k-\boldsymbol{x}_k\|^2}{\scriptstyle 2\sigma}
192     * \f}
193     * for \f$\boldsymbol{w}\f$ for each of the \f$\boldsymbol{m}\f$ outputs (i.e. for each set of \f$y\f$ s).
194     *
195     * This needs to be called before sub() if there have been any calls to addExample().
196     *
197     * If RBFNET_DEBUG is defined, the matrix passed to lapack is printed on stderr.
198     * If RBFNET_VDEBUG is defined, the formula substitutions are also printed.
199     */
200    int solve();
201
202    /**
203     * Set a new value of \f$\sigma\f$. Afterwards, solve() needs to be called again.
204     */
205    void setSigma(const T &newsig) {sigma = newsig;}
206#endif
207    /**
208     * Return the number of examples == the dimension of each weight vector, \f$\boldsymbol{w}\f$.
209     */
210    size_t dim() {return examples.size();}
211
212    /**
213     * Substitute an input into the NN, for the interpolated output.
214     * \param out The output that is generated (assigned by reference)
215     * \param in The input to present
216     */
217    void sub(OUTVEC &out, const INVEC &in) const ;
218    /**
219     * A debugging version of sub(), that shows the formula substitution on
220     * sterr if RBFNET_VDEBUG is defined.
221     */
222    void dsub(OUTVEC &out, const INVEC &in) const ;
223
224    /**
225     * Output a human and machine readable representation of the network to \a o.
226     */
227    std::ostream &output(std::ostream& o) const;
228
229    /**
230     * Read in a previously output() representation of the network from \a in.
231     */
232    std::istream &input(std::istream& i);
233
234    /**
235     * Are we ready for use (i.e. no read errors from the fstream constructor)
236     */
237    bool ready() { return !badbit; }
238};
239
240/** Returns rhs.output(o) */
241template <class T, int INDIM, int OUTDIM, class DISTFUNC>
242std::ostream& operator<< (std::ostream &o, const ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC> &rhs) {
243    return rhs.output(o);
244}
245
246/** Returns rhs.input(i) */
247template <class T, int INDIM, int OUTDIM, class DISTFUNC>
248std::istream& operator>> (std::istream &i, const ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC> &rhs) {
249    return rhs.input(i);
250}
251
252/**
253 * A template instantiation for 2D to 2D interpolation (e.g. calibrating a
254 * projected touch display).
255 */
256typedef ExactRBFNet<double, 2, 2> XYInterp;
257
258/**
259 * A template instantiation of a linear interpolation network (works poorly)
260 */
261typedef ExactRBFNet<double, 2, 2, EulerDistSq<double, 2> > XYLinInterp;
262
263
264//IMPLEMENTATION
265#ifndef RBFNET_READ_ONLY
266#include "lapack_leqn.h"
267#endif
268
269#include <stdio.h>
270#include <math.h>
271
272static inline double gauss(const double sigma, const double distsq) {
273    return exp((distsq)/(2.0*sigma));
274}
275
276static inline float gauss(const float sigma, const float distsq) {
277    return expf((distsq)/(2.0*sigma));
278}
279
280template <class T, int DIM>
281static inline T gdist(const T *const v1, const T *const v2, const T sigma) {
282    return gauss(sigma, distsq<T, DIM>(v1, v2));
283}
284
285#ifndef RBFNET_READ_ONLY
286template <class T, int INDIM, int OUTDIM, class DISTFUNC>
287void ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC>::addExample(const EXAMPLE &ex) {
288    examples.push_back(ex);
289}
290
291template <class T, int INDIM, int OUTDIM, class DISTFUNC>
292int ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC>::solve() {
293    int sz = examples.size();
294    int oks = 0;
295    T *A = new T[sz*sz];
296    T *R = new T[sz];
297    for (int r = 0; r < sz; ++r) {
298#ifdef RBFNET_VDEBUG
299        fprintf(stderr, "\ny = ");
300#endif
301        int c = 0;
302#ifdef RBFNET_EXPONE_OPTIMIZATION
303        for (; c < r; ++c) {
304            A[r*sz + c] = distfunc(&examples[r].in[0], &examples[c].in[0], sigma);
305        }
306        A[r*sz + c++] = 1.0; //e^0 == 1
307#endif
308        for (; c < sz; ++c) {
309            A[r*sz + c] = distfunc(&examples[r].in[0], &examples[c].in[0], sigma);
310        }
311    }
312    for (int o = 0; o < OUTDIM; ++o) {
313        for (int r = 0; r < sz; ++r)
314            R[r] = examples[r].out[o];
315#ifdef RBFNET_DEBUG
316        fprintf(stderr, "\nSolving :\n");
317        for (int _r = 0; _r < sz; ++_r) {
318            fprintf(stderr, "[ %8f", A[_r*sz]);
319            for (int _c = 1; _c < sz; ++ _c) {
320                fprintf(stderr, ", %8f", A[_r*sz + _c]);
321            }
322            fprintf(stderr, "][%c] = [%8f]\n", 'a'+_r, R[_r]);
323        }
324#endif
325        oks += lapack_solve(R, A, sz);
326#ifdef RBFNET_DEBUG
327        if (oks) {
328            fprintf(stderr, "\nNo Solution\n");
329            oks = 0;
330        } else {
331            fputc('\n', stderr);
332            for (int _r = 0; _r < sz; ++_r) {
333                fprintf(stderr, "[%c] = [%8f]\n", 'a'+_r, R[_r]);
334            }
335        }
336#endif
337        weights[o].resize(sz);
338        //weights[o] = std::valarray<T>(R, sz);
339        for (int r = 0; r < sz; ++r)
340            weights[o][r] = R[r];
341    }
342    delete[] R;
343    delete[] A;
344    return oks;
345}
346#endif //#ifndef RBFNET_READ_ONLY
347
348template <class T, int INDIM, int OUTDIM, class DISTFUNC>
349void ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC>::sub(OUTVEC &out, const INVEC &in) const {
350    int sz = examples.size();
351    for (int o = 0; o < OUTDIM; ++o) {
352        out.v[o] = weights[o][0] * distfunc(&in[0], &examples[0].in[0], sigma);
353        for (int i = 1; i < sz; ++i) {
354            out.v[o] += weights[o][i] * distfunc(&in[0], &examples[i].in[0], sigma);
355        }
356    }
357}
358
359template <class T, int INDIM, int OUTDIM, class DISTFUNC>
360void ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC>::dsub(OUTVEC &out, const INVEC &in) const {
361    int sz = examples.size();
362    for (int o = 0; o < OUTDIM; ++o) {
363#ifdef RBFNET_VDEBUG
364        fprintf(stderr, "\ny = (%f)", weights[o][0]);
365#endif
366        out.v[o] = weights[o][0] * distfunc(&in[0], &examples[0].in[0], sigma);
367        for (int i = 1; i < sz; ++i) {
368#ifdef RBFNET_VDEBUG
369            fprintf(stderr, "(%f)", weights[o][i]);
370#endif
371            out.v[o] += weights[o][i] * distfunc(&in[0], &examples[i].in[0], sigma);
372        }
373    }
374}
375
376template <class T, int DIM>
377static inline T dgdist(const T *const v1, const T *const v2, const T sigma) {
378    T sum = 0;
379    for (unsigned i = 0; i < DIM; ++i) {
380        const T d = v1[i] - v2[i];
381        sum += d*d;
382    }
383#ifdef RBFNET_VDEBUG
384    fprintf(stderr, "we^(%f/2*%f) ", sum, sigma);
385#endif
386    return gauss(sigma, sum);
387}
388
389/** Output a MyVec */
390template <class T, int DIM>
391std::ostream &operator<< (std::ostream &o, const MyVec<T, DIM> &v) {
392    o << v[0];
393    for (int i = 1; i < DIM; ++i)
394        o << ' ' << v[i];
395    return o;
396}
397
398/** Input a MyVec */
399template <class T, int DIM>
400std::istream &operator>> (std::istream &i, MyVec<T, DIM> &v) {
401    i >> v[0];
402    for (int j = 1; j < DIM; ++j)
403        i >> v[j];
404    return i;
405}
406
407/** Output an Example */
408template <class T, int INDIM, int OUTDIM>
409std::ostream &operator<< (std::ostream &o, const Example<T, INDIM, OUTDIM> &e) {
410    return o << e.in << "\n" << e.out;
411}
412
413/** Input an Example */
414template <class T, int INDIM, int OUTDIM>
415std::istream &operator>> (std::istream &i, Example<T, INDIM, OUTDIM> &e) {
416    return i >> e.in >> e.out;
417}
418
419template <class T, int INDIM, int OUTDIM, class DISTFUNC>
420std::ostream &ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC>::output(std::ostream& o) const {
421    o.flags(std::ios_base::scientific);
422    o.precision(15);
423    o << "ExactRBFNet< " << typeid(T).name() << " , " << INDIM << " , " << OUTDIM << " >\n"
424        << "sigma= " << sigma << "\n";
425    o << "weights==examples.size()= " << examples.size();
426    for (int p = 0; p < OUTDIM; ++p) {
427        o << "\nweights[" << p << "]= ";
428        for (size_t i = 0; i < weights[p].size(); ++i)
429            o << weights[p][i] << " ";
430    }
431    for (typename EXAMPLEVEC::const_iterator it = examples.begin(); it != examples.end(); ++it) {
432        o << "\n\n" << *it;
433    }
434    return o;
435}
436
437template <class T, int INDIM, int OUTDIM, class DISTFUNC>
438std::istream &ExactRBFNet<T, INDIM, OUTDIM, DISTFUNC>::input(std::istream& i) {
439    std::string d, type;
440    int indim, outdim, size;
441    i >> d >> type >> d >> indim >> d >> outdim >> d >> d >> sigma;
442    i >> d >> size;
443    if (i) {
444        for (int o = 0; i && o < OUTDIM; ++o) {
445            i >> d;
446            weights[o].resize(size);
447            for (int j = 0; i && j < size; ++j)
448                i >> weights[o][j];
449        }
450        examples.reserve(size);
451        for (int j = 0; i && j < size; ++j) {
452            EXAMPLE eg;
453            i >> eg;
454            examples.push_back(eg);
455        }
456    }
457    if (!i)
458        badbit = true;
459    return i;
460}
461
462#endif //#ifndef RBN_SOLVE_DOT_AITCH
Note: See TracBrowser for help on using the browser.