Matrix.h

00001 // ----------------------------------------------------------------------------
00002 //
00003 // tagGame - Example code from the book:
00004 //
00005 //           Artificial Intelligence for Computer Games: An Introduction
00006 //           by John David Funge
00007 //
00008 //           www.ai4games.org
00009 //
00010 // Source code distributed under the Copyright (c) 2003-2007, John David Funge
00011 // Original author: John David Funge (www.jfunge.com)
00012 //
00013 // Licensed under the Academic Free License version 3.0 
00014 // (for details see LICENSE.txt in this directory).
00015 //
00016 // ----------------------------------------------------------------------------
00017 
00018 #ifndef TG_MATRIX_H
00019 #define TG_MATRIX_H
00020 
00021 #include "Vec.h"
00022 
00023 namespace tagGame
00024 {
00027    template <class T>
00028    class Matrix : public std::vector< Vec<T> >
00029    {
00030    public:
00031       Matrix(size_t const rowCount = 0, size_t const colCount = 0);
00032       Matrix(std::vector< Vec<T> > const& rows);
00033       Matrix(Matrix const& m);
00034       ~Matrix();
00035 
00036       size_t getRowCount() const;
00037       size_t getColCount() const;
00038       void setRowColCount(size_t const rowCount, size_t const colCount);
00039       Vec<T> const& getRow(size_t const i) const;
00040       Matrix& setRow(size_t const i, Vec<T> const& v);
00041       Vec<T> const& getCol(size_t const j) const;
00042       Matrix& setCol(size_t const j, Vec<T> const& v);
00043       std::vector< Vec<T> > const& getRows() const;
00044       std::vector< Vec<T> > const& getCols() const;
00045 
00046       Matrix& operator=(Matrix const& m);
00047 
00048       bool operator==(Matrix const& m) const;
00049       bool isAlmostEq(Matrix const& m) const;
00050       bool isAlmostZero() const;
00051 
00052       Matrix& set(T const x);
00053       Matrix& add(Matrix const& m);
00054       Matrix& subtract(Matrix const& m);
00055       Matrix& scale(T const x);
00056       Matrix& multiply(Matrix const& m);
00057       Matrix& transpose();
00058       Matrix& invert();
00059 
00060       Vec<T> multiply(Vec<T> const& v) const;
00061 
00062       T minElement() const;
00063       T maxElement() const;
00064       size_t argMax() const;
00065       size_t argMin() const;
00066 
00067       Matrix& shuffle();
00068       Matrix& randomize();
00069    protected:
00070    private:
00071       void computeCol(size_t const j);
00072 
00073       // whenever column information is computed it is cached so that it can be retrieved
00074       // quickly whenever it is still valid
00075       mutable std::vector< Vec<T> > cols;
00076       mutable std::vector<bool> colsValid;
00077    };
00078 
00079    typedef Matrix<int> IntMatrix;
00080    typedef Matrix<Real> RealMatrix;
00081 
00082    template <class T>
00083    Matrix<T>::Matrix(size_t const rowCount, size_t const colCount) :
00084       std::vector< Vec<T> >(rowCount),
00085       cols(colCount),
00086       colsValid(colCount)
00087    {
00088       for (size_t i = 0; i < rowCount; i++)
00089       {
00090          (*this)[i].resize(colCount);
00091          // Note: Vec sets to 0
00092       }
00093       for (size_t j = 0; j < colCount; j++)
00094       {
00095          cols[j].resize(rowCount);
00096       }
00097    
00098       fill(colsValid.begin(), colsValid.end(), true);
00099    }
00100    
00101    template <class T>
00102    Matrix<T>::Matrix(std::vector< Vec<T> > const& rows) :
00103       std::vector< Vec<T> >::vector(rows)
00104    {
00105       int colCountPrev(0);
00106       for (size_t i = 0; i < std::vector<T>::size(); i++)
00107       {
00108          int const n((*this)[i].size());
00109          TG_ASSERT(colCountPrev < 0 || colCountPrev == n);
00110          colCountPrev = n;
00111       }
00112       cols.resize(colCountPrev);
00113       for (size_t j = 0; j < colCountPrev; j++)
00114       {
00115          cols[j].resize(rows.size());
00116       }
00117       colsValid.resize(colCountPrev);
00118       fill(colsValid.begin(), colsValid.end(), false);
00119    }
00120    
00121    template <class T>
00122    Matrix<T>::Matrix(Matrix const& m) :
00123       std::vector< Vec<T> >(m),
00124       cols(m.cols),
00125       colsValid(m.colsValid)
00126    {
00127    }
00128    
00129    template <class T>
00130    Matrix<T>::~Matrix()
00131    {
00132    }
00133    
00134    template <class T>
00135    size_t Matrix<T>::getRowCount() const
00136    {
00137       return this->size();
00138    }
00139    
00140    template <class T>
00141    size_t Matrix<T>::getColCount() const
00142    {
00143       return cols.size();
00144    }
00145    
00146    template <class T>
00147    void Matrix<T>::setRowColCount(size_t const rowCount, size_t const colCount)
00148    {
00149       std::vector<T>::resize(rowCount);
00150       for (size_t i = 0; i < rowCount; i++)
00151       {
00152          (*this)[i].resize(colCount);
00153          // Note: Vec sets to 0
00154       }
00155       cols.resize(colCount);
00156       for (size_t j = 0; j < colCount; j++)
00157       {
00158          cols[j].resize(rowCount);
00159       }
00160       colsValid.resize(colCount);
00161       fill(colsValid.begin(), colsValid.end(), false);
00162    }
00163    
00164    template <class T>
00165    Vec<T> const&  Matrix<T>::getRow(size_t const i) const
00166    {
00167       TG_ASSERT(0 <= i && i < getRowCount());
00168    
00169       return (*this)[i];
00170    }
00171    
00172    template <class T>
00173    Matrix<T>& Matrix<T>::setRow(size_t const i, Vec<T> const& v)
00174    {
00175       TG_ASSERT(0 <= i && i < getRowCount());
00176       TG_ASSERT((*this)[i].size() == v.size());
00177    
00178       (*this)[i] = v;
00179    
00180       return *this;
00181    }
00182    
00183    template <class T>
00184    Vec<T> const& Matrix<T>::getCol(size_t const j) const
00185    {
00186       TG_ASSERT(0 <= j && j < cols.size());
00187       computeCol(j);
00188    
00189       return cols[j];
00190    }
00191    
00192    template <class T>
00193    Matrix<T>& Matrix<T>::setCol(size_t const j, Vec<T> const& v)
00194    {
00195       TG_ASSERT(0 <= j && j < cols.size());
00196       computeCol(j);
00197       TG_ASSERT(cols[j].size() == v.size());
00198    
00199       TG_ASSERT(false); // not implemented
00200    
00201       return *this;
00202    }
00203    
00204    template <class T>
00205    std::vector< Vec<T> > const& Matrix<T>::getRows() const
00206    {
00207       return static_cast< std::vector< Vec<T> >const&>(*this);
00208    }
00209    
00210    template <class T>
00211    std::vector< Vec<T> > const& Matrix<T>::getCols() const
00212    {
00213       for (size_t j = 0; j < cols.size(); j++)
00214       {
00215          computeCol(j);
00216       }
00217       return cols;
00218    }
00219    
00220    template <class T>
00221    Matrix<T>& Matrix<T>::operator=(Matrix const& m)
00222    {
00223       if (this != &m)
00224       {
00225          static_cast< std::vector< Vec<T> >&>(*this) = static_cast< std::vector< Vec<T> >const&>(m);;
00226          cols = m.cols;
00227          colsValid = m.colsValid;
00228       }
00229       return *this;
00230    }
00231    
00232    // TODO: make the Matrix methods take advantage of <algorithm>,
00233    //       like the Vec class
00234    
00235    template <class T>
00236    bool Matrix<T>::isAlmostEq(Matrix const& m) const
00237    {
00238       for (size_t i = 0; i < this->rows.size(); i++)
00239       {
00240          if (!((*this)[i].isAlmostEq(m[i]))) { return false; }
00241       }
00242       return true;
00243    }
00244    
00245    template <class T>
00246    bool Matrix<T>::isAlmostZero() const
00247    {
00248       for (size_t i = 0; i < this->rows.size(); i++)
00249       {
00250          if (!((*this)[i].isAlmostZero())) { return false; }
00251       }
00252       return true;
00253    }
00254    
00255    template <class T>
00256    Matrix<T>& Matrix<T>::set(T const x)
00257    {
00258       for (size_t i = 0; i < std::vector<T>::size(); i++)
00259       {
00260          (*this)[i].set(0);
00261       }
00262       fill(colsValid.begin(), colsValid.end(), false);
00263    
00264       return *this;
00265    }
00266    
00267    template <class T>
00268    Matrix<T>& Matrix<T>::add(Matrix const& m)
00269    {
00270       for (size_t i = 0; i < std::vector<T>::size(); i++)
00271       {
00272          (*this)[i].add(m[i]);
00273       }
00274       fill(colsValid.begin(), colsValid.end(), false);
00275    
00276       return *this;
00277    }
00278    
00279    template <class T>
00280    Matrix<T>& Matrix<T>::subtract(Matrix const& m)
00281    {
00282       for (size_t i = 0; i < this->rows.size(); i++)
00283       {
00284          (*this)[i].subtract(m[i]);
00285       }
00286       fill(colsValid.begin(), colsValid.end(), false);
00287    
00288       return *this;
00289    }
00290    
00291    template <class T>
00292    Matrix<T>& Matrix<T>::scale(T const x)
00293    {
00294       for (size_t i = 0; i < std::vector<T>::size(); i++)
00295       {
00296          (*this)[i].scale(x);
00297       }
00298       fill(colsValid.begin(), colsValid.end(), false);
00299    
00300       return *this;
00301    }
00302    
00303    template <class T>
00304    Matrix<T>& Matrix<T>::multiply(Matrix const& m)
00305    {
00306       TG_ASSERT(getColCount() == m.getRowCount());
00307       Matrix mOld(*this);
00308       setRowColCount(getRowCount(), m.getColCount());
00309    
00310       for (size_t i = 0; i < getRowCount(); i++)
00311       {
00312          for (size_t j = 0; j < getColCount(); j++)
00313          {
00314             (*this)[i][j] = mOld.getRow(i).dot(m.getCol(j));
00315             cols[j][i] = (*this)[i][j]; // might as well update
00316          }
00317       }
00318       fill(colsValid.begin(), colsValid.end(), true);
00319    
00320       return *this;
00321    }
00322    
00323    template <class T>
00324    Matrix<T>& Matrix<T>::transpose()
00325    {
00326       std::vector< Vec<T> > oldRows(getRows());
00327       static_cast< std::vector< Vec <T> >&>(*this) = getCols();
00328       cols = oldRows;
00329       colsValid.resize(cols.size());
00330    
00331       fill(colsValid.begin(), colsValid.end(), true);
00332    
00333       return *this;
00334    }
00335    
00336    template <class T>
00337    Matrix<T>& Matrix<T>::invert()
00338    {
00339       TG_ASSERT(false); // not implemented
00340    
00341       fill(colsValid.begin(), colsValid.end(), false);
00342    
00343       return *this;
00344    }
00345    
00346    template <class T>
00347    Vec<T> Matrix<T>::multiply(Vec<T> const& v) const
00348    {
00349       TG_ASSERT(getColCount() == v.size());
00350       Vec<T> result(getRowCount());
00351    
00352       for (size_t i = 0; i < getRowCount(); i++)
00353       {
00354          result[i] = (*this)[i].dot(v);
00355       }
00356    
00357       return result;
00358    }
00359    
00360    template <class T>
00361    T Matrix<T>::minElement() const
00362    {
00363       // would like to initialize this to Inf for Real and INT_MAX for int
00364       // could do with appropriate specializations of some function getBottom<T>
00365       T x(T(0));
00366    
00367       for (size_t i = 0; i < std::vector<T>::size(); i++)
00368       {
00369          if (0 == i) { x = (*this)[i].min(); } // unnecessary if initialized properly
00370          else { x = std::min(x, (*this)[i].min()); }
00371       }
00372       return x;
00373    }
00374    
00375    template <class T>
00376    T Matrix<T>::maxElement() const
00377    {
00378       // would like to initialize this to -Inf for Real and INT_MIN for int
00379       // could do with appropriate specializations of some function getBottom<T>
00380       T x(T(0));
00381    
00382       for (size_t i = 0; i < std::vector<T>::size(); i++)
00383       {
00384          if (0 == i) { x = (*this)[i].max(); } // unnecessary if initialized properly
00385          else { x = std::max(x, (*this)[i].max()); }
00386       }
00387       return x;
00388    }
00389    
00390    template <class T>
00391    size_t Matrix<T>::argMax() const
00392    {
00393       Util::error("not implemented");
00394       return 0;
00395    /*
00396       // Note: only returns the row with the maximum element
00397       // TODO: improve the implementation
00398       int which(-1);
00399    
00400       T const x(getMax());
00401    
00402       for (size_t i = 0; i < std::vector<T>::size(); i++)
00403       {
00404          if (x == (*this)[i].max())
00405          {
00406             which = i;
00407             break;
00408          }
00409       }
00410    
00411       return which;
00412    */
00413    }
00414    
00415    template <class T>
00416    size_t Matrix<T>::argMin() const
00417    {
00418       Util::error("not implemented");
00419       return 0;
00420    
00421    /*
00422       int which(-1);
00423    
00424       T const x(getMin());
00425    
00426       for (size_t i = 0; i < std::vector<T>::size(); i++)
00427       {
00428          if (x == (*this)[i].min())
00429          {
00430             which = i;
00431             break;
00432          }
00433       }
00434    
00435       return which;
00436    */
00437    }
00438    
00439    template <class T>
00440    Matrix<T>& Matrix<T>::shuffle()
00441    {
00442       // TODO: shuffle across rows, not just within rows
00443       for (size_t i = 0; i < std::vector<T>::size(); i++)
00444       {
00445          (*this)[i].shuffle();
00446       }
00447    
00448       return *this;
00449    }
00450    
00451    template <class T>
00452    Matrix<T>& Matrix<T>::randomize()
00453    {
00454       for (size_t i = 0; i < std::vector<T>::size(); i++)
00455       {
00456          (*this)[i].randomize();
00457       }
00458    
00459       return *this;
00460    }
00461    
00462    template <class T>
00463    void Matrix<T>::computeCol(size_t const j)
00464    {
00465       if (colsValid[j]) { return; }
00466    
00467       for (size_t i = 0; i < std::vector<T>::size(); i++)
00468       {
00469          cols[j][i] = (*this)[i][j];
00470       }
00471    
00472       colsValid[j] = true;
00473    }
00474 }
00475 
00476 #endif

Generated on Sat Mar 31 22:30:54 2007 for tagGame by  doxygen 1.5.1