Filter.java

/*
 * $Id: Filter.java,v 1.21 2008/07/17 07:30:03 koga Exp $
 * 
 * Copyright (C) 2004 Koga Laboratory. All rights reserved.
 */
package org.mklab.tool.signal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.DoubleMatrix;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.DoubleNumber;
import org.mklab.nfc.scalar.RealNumericalScalar;
import org.mklab.tool.matrix.Makecolv;
import org.mklab.tool.matrix.Makerowv;


/**
 * Computes the signal obtained by passing data through a digital filter.
 * 
 * <p> Digital filter
 * 
 * @author koga
 * @version $Revision: 1.21 $
 */
public class Filter {

  /**
   * Passes the data <code>x</code> through the filter
   * 
   * <pre><code> y(n) = b(1)*x(n) + b(2)*x(n-1) + ... + b(nb+1)*x(n-nb) - a(2)*y(n-1) - ... - a(na+1)*y(n-na) </code></pre>
   * 
   * starting from an all-zero filter state.
   * 
   * @param b numerator coefficients
   * @param a denominator coefficients
   * @param x input signal
   * @return list whose first element is the output signal (filtered signal) and whose second element is the final filter state
   */
  public static List<DoubleMatrix> filter(DoubleMatrix b, DoubleMatrix a, DoubleMatrix x) {
    int na = a.length() - 1;
    int nb = b.length() - 1;
    // Zero initial state with max(na, nb) rows.
    DoubleMatrix zi = a.createZero(Math.max(na, nb), 1);

    return filter(b, a, x, zi);
  }

  /**
   * Filters the data starting from a given initial state and also returns the final state.
   * 
   * <p>When <code>a_</code> has more than one element, both coefficient vectors are divided by the leading denominator coefficient <code>a(1)</code>. If the
   * magnitude of <code>a(1)</code> does not exceed the machine epsilon, the message with key <code>Filter.0</code> is printed to <code>System.err</code> and the
   * computation continues.
   * 
   * <p>NOTE(review): when <code>a_</code> has a single element no normalization is applied, i.e. <code>a(1)</code> is effectively treated as 1 - confirm that
   * this is intended.
   * 
   * <p>If <code>x_</code> has more columns than rows, the returned output and state are transposed; otherwise they are returned as columns.
   * 
   * @param b_ numerator coefficients
   * @param a_ denominator coefficients
   * @param x_ input signal
   * @param zi_ initial state of the filter (presumably max(na, nb) elements, as created by the three-argument overload - confirm)
   * @return list whose first element is the output signal (filtered signal) and whose second element is the final filter state
   */
  public static List<DoubleMatrix> filter(DoubleMatrix b_, DoubleMatrix a_, DoubleMatrix x_, DoubleMatrix zi_) {
    DoubleMatrix b = Makerowv.makerowv(b_);
    DoubleMatrix a = Makerowv.makerowv(a_);
    DoubleMatrix x = Makerowv.makerowv(x_);

    int na = a.length() - 1;
    int nb = b.length() - 1;
    int nx = x.length();

    DoubleMatrix zi = Makecolv.makecolv(zi_);

    DoubleNumber eps = a.getElement(1, 1).getMachineEpsilon();

    // Coefficients after the leading one, as columns; null when there are none.
    DoubleMatrix a2t = null;
    DoubleMatrix b2t = null;

    if (na > 0) {
      if (a.getElement(1).abs().isLessThanOrEquals(eps)) {
        System.err.println(Messages.getString("Filter.0")); //$NON-NLS-1$
      }
      // Scale b BEFORE a: once a has been normalized its leading element is 1,
      // so scaling b afterwards would leave the numerator unnormalized.
      b = b.multiply(a.getElement(1).inverse());
      a = a.multiply(a.getElement(1).inverse());
      a2t = a.getSubVector(2, a.length()).transpose();
    }
    if (nb > 0) {
      b2t = b.getSubVector(2, b.length()).transpose();
    }

    DoubleMatrix y = a.createZero(nx, 1);

    DoubleMatrix z = zi;
    // Zero padding that extends the shorter coefficient column to the state length.
    DoubleMatrix zab = a.createZero(Math.abs(na - nb), 1);

    // In every branch below: y(i) = b(1)*x(i) + z(1), then the state is shifted up by one
    // row (a zero enters at the bottom) and updated with b(2..)*x(i) - a(2..)*y(i).
    // Single-row states are handled separately because there is nothing to shift.
    if (nb > na) {
      if (nb == 1) {
        // na == 0 and a single state row.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = b2t.multiply(x.getElement(i));
        }
      } else if (na == 0) {
        // No feedback terms.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1)).add(b2t.multiply(x.getElement(i)));
        }
      } else {
        // Feedback column is shorter than the state: pad it with zab.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          DoubleMatrix shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          DoubleMatrix feedForward = b2t.multiply(x.getElement(i));
          DoubleMatrix feedback = a2t.multiply(y.getElement(i)).appendDown(zab);
          z = shifted.add(feedForward).subtract(feedback);
        }
      }
    } else if (nb == na) {
      if (na == 0) {
        // Pure gain: the state is empty and never read.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)));
        }
      } else if (na == 1) {
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = b2t.multiply(x.getElement(i)).subtract(a2t.multiply(y.getElement(i)));
        }
      } else {
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          DoubleMatrix shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          DoubleMatrix feedForward = b2t.multiply(x.getElement(i));
          DoubleMatrix feedback = a2t.multiply(y.getElement(i));
          z = shifted.add(feedForward).subtract(feedback);
        }
      }
    } else {
      if (na == 1) {
        // nb == 0 and a single state row.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = a2t.multiply(y.getElement(i)).unaryMinus();
        }
      } else if (nb == 0) {
        // No delayed input terms.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          DoubleMatrix shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          DoubleMatrix feedback = a2t.multiply(y.getElement(i));
          z = shifted.subtract(feedback);
        }
      } else {
        // Feed-forward column is shorter than the state: pad it with zab.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          DoubleMatrix shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          DoubleMatrix feedForward = b2t.multiply(x.getElement(i)).appendDown(zab);
          DoubleMatrix feedback = a2t.multiply(y.getElement(i));
          z = shifted.add(feedForward).subtract(feedback);
        }
      }
    }

    // Return the results in the orientation of the caller's input.
    if (x_.getColumnSize() > x_.getRowSize()) {
      y = y.transpose();
      z = z.transpose();
    }

    return new ArrayList<>(Arrays.asList(new DoubleMatrix[] {y, z}));
  }

  /**
   * Passes the data <code>x</code> through the filter
   * 
   * <pre><code> y(n) = b(1)*x(n) + b(2)*x(n-1) + ... + b(nb+1)*x(n-nb) - a(2)*y(n-1) - ... - a(na+1)*y(n-na) </code></pre>
   * 
   * starting from an all-zero filter state.
   * 
   * @param b numerator coefficients
   * @param a denominator coefficients
   * @param x input signal
   * @param <RS> type of real scalar
   * @param <RM> type of real matrix
   * @param <CS> type of complex scalar
   * @param <CM> type of complex matrix
   * @return list whose first element is the output signal (filtered signal) and whose second element is the final filter state
   */
  public static <RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> List<RM> filter(
      RM b, RM a, RM x) {
    int na = a.length() - 1;
    int nb = b.length() - 1;
    // Zero initial state with max(na, nb) rows.
    RM zi = a.createZero(Math.max(na, nb), 1);

    return filter(b, a, x, zi);
  }

  /**
   * Filters the data starting from a given initial state and also returns the final state.
   * 
   * <p>When <code>a_</code> has more than one element, both coefficient vectors are divided by the leading denominator coefficient <code>a(1)</code>. If the
   * magnitude of <code>a(1)</code> does not exceed the machine epsilon, the message with key <code>Filter.0</code> is printed to <code>System.err</code> and the
   * computation continues.
   * 
   * <p>NOTE(review): when <code>a_</code> has a single element no normalization is applied, i.e. <code>a(1)</code> is effectively treated as 1 - confirm that
   * this is intended.
   * 
   * <p>If <code>x_</code> has more columns than rows, the returned output and state are transposed; otherwise they are returned as columns.
   * 
   * @param b_ numerator coefficients
   * @param a_ denominator coefficients
   * @param x_ input signal
   * @param zi_ initial state of the filter (presumably max(na, nb) elements, as created by the three-argument overload - confirm)
   * @param <RS> type of real scalar
   * @param <RM> type of real matrix
   * @param <CS> type of complex scalar
   * @param <CM> type of complex matrix
   * @return list whose first element is the output signal (filtered signal) and whose second element is the final filter state
   */
  public static <RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> List<RM> filter(
      RM b_, RM a_, RM x_, RM zi_) {
    RM b = Makerowv.makerowv(b_);
    RM a = Makerowv.makerowv(a_);
    RM x = Makerowv.makerowv(x_);

    int na = a.length() - 1;
    int nb = b.length() - 1;
    int nx = x.length();

    RM zi = Makecolv.makecolv(zi_);

    RS eps = a.getElement(1, 1).getMachineEpsilon();

    // Coefficients after the leading one, as columns; null when there are none.
    RM a2t = null;
    RM b2t = null;

    if (na > 0) {
      if (a.getElement(1).abs().isLessThanOrEquals(eps)) {
        System.err.println(Messages.getString("Filter.0")); //$NON-NLS-1$
      }
      // Scale b BEFORE a: once a has been normalized its leading element is 1,
      // so scaling b afterwards would leave the numerator unnormalized.
      b = b.multiply(a.getElement(1).inverse());
      a = a.multiply(a.getElement(1).inverse());
      a2t = a.getSubVector(2, a.length()).transpose();
    }
    if (nb > 0) {
      b2t = b.getSubVector(2, b.length()).transpose();
    }

    RM y = a.createZero(nx, 1);

    RM z = zi;
    // Zero padding that extends the shorter coefficient column to the state length.
    RM zab = a.createZero(Math.abs(na - nb), 1);

    // In every branch below: y(i) = b(1)*x(i) + z(1), then the state is shifted up by one
    // row (a zero enters at the bottom) and updated with b(2..)*x(i) - a(2..)*y(i).
    // Single-row states are handled separately because there is nothing to shift.
    if (nb > na) {
      if (nb == 1) {
        // na == 0 and a single state row.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = b2t.multiply(x.getElement(i));
        }
      } else if (na == 0) {
        // No feedback terms.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1)).add(b2t.multiply(x.getElement(i)));
        }
      } else {
        // Feedback column is shorter than the state: pad it with zab.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          RM shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          RM feedForward = b2t.multiply(x.getElement(i));
          RM feedback = a2t.multiply(y.getElement(i)).appendDown(zab);
          z = shifted.add(feedForward).subtract(feedback);
        }
      }
    } else if (nb == na) {
      if (na == 0) {
        // Pure gain: the state is empty and never read.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)));
        }
      } else if (na == 1) {
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = b2t.multiply(x.getElement(i)).subtract(a2t.multiply(y.getElement(i)));
        }
      } else {
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          RM shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          RM feedForward = b2t.multiply(x.getElement(i));
          RM feedback = a2t.multiply(y.getElement(i));
          z = shifted.add(feedForward).subtract(feedback);
        }
      }
    } else {
      if (na == 1) {
        // nb == 0 and a single state row.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          z = a2t.multiply(y.getElement(i)).unaryMinus();
        }
      } else if (nb == 0) {
        // No delayed input terms.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          RM shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          RM feedback = a2t.multiply(y.getElement(i));
          z = shifted.subtract(feedback);
        }
      } else {
        // Feed-forward column is shorter than the state: pad it with zab.
        for (int i = 1; i <= nx; i++) {
          y.setElement(i, 1, b.getElement(1).multiply(x.getElement(i)).add(z.getElement(1)));
          RM shifted = z.getSubMatrix(2, z.getRowSize(), 1, 1).appendDown(a.createZero(1, 1));
          RM feedForward = b2t.multiply(x.getElement(i)).appendDown(zab);
          RM feedback = a2t.multiply(y.getElement(i));
          z = shifted.add(feedForward).subtract(feedback);
        }
      }
    }

    // Return the results in the orientation of the caller's input.
    if (x_.getColumnSize() > x_.getRowSize()) {
      y = y.transpose();
      z = z.transpose();
    }

    List<RM> yz = new ArrayList<>();
    yz.add(y);
    yz.add(z);
    return yz;
  }

}