Lqr.java

/*
 * $Id: Lqr.java,v 1.39 2008/07/17 07:30:03 koga Exp $
 *
 * Copyright (C) 2004 Koga Laboratory. All rights reserved.
 */
package org.mklab.tool.control;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.mklab.nfc.eig.DoubleEigenSolution;
import org.mklab.nfc.eig.EigenSolution;
import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.DoubleComplexMatrix;
import org.mklab.nfc.matrix.DoubleMatrix;
import org.mklab.nfc.matrix.IndexedMatrix;
import org.mklab.nfc.matrix.IntMatrix;
import org.mklab.nfc.matrix.NormType;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.DoubleNumber;
import org.mklab.nfc.scalar.RealNumericalScalar;


/**
 * 連続時間システムのLQRを求めるクラスです。
 * 
 * <p>Continuous-time linear quadratic regulator
 * 
 * @author koga
 * @version $Revision: 1.39 $
 * @see org.mklab.tool.control.Dlqr
 * @see org.mklab.tool.control.Lqe
 */
public class Lqr {

  /**
   * 連続時間線形システム
   * 
   * <pre><code> dx/dt = Ax + Bu </code></pre>
   * 
   * について、二次形式評価関数
   * 
   * <pre><code> J = Integral (x#Qx + u#Ru) dt </code></pre>
   * 
   * を最小にする、最適状態フィードバック則<code>u = -Fx</code>のフィードバックゲイン行列<code>F</code>と
   * 
   * <p>リカッティ方程式
   * 
   * <pre><code> P A + A# P - P B R&tilde; B# P + Q = 0 </code></pre>
   * 
   * の解<code>P</code>を要素とするリストを返します。
   * 
   * @param A 連続時間系のシステム行列 A
   * @param B 連続時間系のシステム行列 B
   * @param Q 重み行列(状態)
   * @param R 重み行列(入力)
   * @return {F,P} (状態フィードバックゲイン, リカッティ方程式の解)
   */
  public static List<DoubleMatrix> lqr(DoubleMatrix A, DoubleMatrix B, DoubleMatrix Q, DoubleMatrix R) {
    final DoubleMatrix P = getRiccatiSolution(A, B, Q, R);
    final DoubleMatrix F = R.leftDivide(B.conjugateTranspose().multiply(P));

    return new ArrayList<>(Arrays.asList(new DoubleMatrix[] {F, P}));
  }

  /**
   * 連続時間線形システム
   * 
   * <pre><code> dx/dt = Ax + Bu </code></pre>
   * 
   * について、二次形式評価関数
   * 
   * <pre><code> J = Integral (x#Qx + u#Ru) dt </code></pre>
   * 
   * を最小にする、最適状態フィードバック則<code>u = -Fx</code>のフィードバックゲイン行列<code>F</code>と
   * 
   * <p>リカッティ方程式
   * 
   * <pre><code> P A + A# P - P B R&tilde; B# P + Q = 0 </code></pre>
   * 
   * の解<code>P</code>を要素とするリストを返します。
   * 
   * @param <RS> スカラーの型
   * @param <RM> 行列の型
   * @param <CS> 複素スカラーの型
   * @param <CM> 複素行列の型
   * @param A 連続時間系のシステム行列 A
   * @param B 連続時間系のシステム行列 B
   * @param Q 重み行列(状態)
   * @param R 重み行列(入力)
   * @return {F,P} (状態フィードバックゲイン, リカッティ方程式の解)
   */
  public static <RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> List<RM> lqr(
      RM A, RM B, RM Q, RM R) {
    final RM P = getRiccatiSolution(A, B, Q, R);
    final RM F = R.leftDivide(B.conjugateTranspose().multiply(P));
    
    final List<RM> ans = new ArrayList<>();
    ans.add(F);
    ans.add(P);

    return ans;
  }

  /**
   * 評価関数が
   * 
   * <pre><code> J = Integral (x#Qx + u#Ru + 2 x#Su) dt </code></pre>
   * 
   * となるよう、入力<code>u</code>と状態<code>x</code>の積の重み行列を <code>S</code>とします。
   * 
   * @param A システム行列
   * @param B 入力行列
   * @param Q 状態に関する重み行列
   * @param R 入力に関する重み行列
   * @param S 入力と状態に関する重み行列
   * @return {F,P} (状態フィードバックゲイン, リカッティ方程式の解)
   */
  public static List<DoubleMatrix> lqr(DoubleMatrix A, DoubleMatrix B, DoubleMatrix Q, DoubleMatrix R, DoubleMatrix S) {
    final DoubleMatrix AA = A.subtract(B.divide(R).multiply(S.conjugateTranspose()));
    final DoubleMatrix QQ = Q.subtract(S.divide(R).multiply(S.conjugateTranspose()));

    if (S.getRowSize() != Q.getRowSize() || S.getColumnSize() != R.getColumnSize()) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.0")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final DoubleMatrix P = getRiccatiSolution(AA, B, QQ, R);
    final DoubleMatrix F = R.leftDivide(S.conjugateTranspose().add(B.conjugateTranspose().multiply(P)));

    return new ArrayList<>(Arrays.asList(new DoubleMatrix[] {F, P}));
  }

  /**
   * 連続時間系のリカッティ方程式の解を返します。
   * 
   * @param A システム行列 A
   * @param B システム行列 B
   * @param Q 重み行列(状態)
   * @param R 重み行列(入力)
   * @return P リカッティ方程式の解
   */
  private static DoubleMatrix getRiccatiSolution(DoubleMatrix A, DoubleMatrix B, DoubleMatrix Q, DoubleMatrix R) {
    final String message;
    if ((message = Abcdchk.abcdchk(A, B)).length() > 0) {
      throw new IllegalArgumentException(message);
    }

    if (A.isSameSize(Q) == false) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.2")); //$NON-NLS-1$ //$NON-NLS-2$
    }
    if (B.getColumnSize() != R.getRowSize()) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.3")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final boolean qIsPositive = isPositive(Q);
    final boolean qIsSymmetric = isSymmetric(Q);

    if (qIsPositive == false || qIsSymmetric == false) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.4")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final boolean rIsPositive = isPositive(R);
    final boolean rIsSymmetric = isSymmetric(R);

    if (rIsPositive == false || rIsSymmetric == false) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.5")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final DoubleMatrix h = A.appendRight(B.divide(R).multiply(B.conjugateTranspose())).appendDown(Q.appendRight(A.conjugateTranspose().unaryMinus()));
    final DoubleEigenSolution dv = h.eigenDecompose();
    final DoubleComplexMatrix D = dv.getValue().diagonalToVector();
    final DoubleComplexMatrix vector = dv.getVector();

    // Sort on real part of eigenvalues
    final IndexedMatrix<DoubleNumber, DoubleMatrix> list5 = D.getRealPart().sort();
    final IntMatrix idx = list5.getIndices().transpose();
    final DoubleMatrix Dr = list5.getMatrix();

    final int size = A.getRowSize();

    if (!(Dr.getElement(size).isLessThan(0) && Dr.getElement(size + 1).isGreaterThan(0))) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.6")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final DoubleComplexMatrix u1 = vector.getSubMatrix(size + 1, 2 * size, idx.getSubVector(1, size));
    final DoubleComplexMatrix v1 = vector.getSubMatrix(1, size, idx.getSubVector(1, size));
    final DoubleMatrix P = u1.divide(v1).getRealPart().unaryMinus();
    return P;
  }

  private static <RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> RM getRiccatiSolution(
      RM A, RM B, RM Q, RM R) {
    final String message;
    if ((message = Abcdchk.abcdchk(A, B)).length() > 0) {
      throw new IllegalArgumentException(message);
    }

    if (A.isSameSize(Q) == false) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.2")); //$NON-NLS-1$ //$NON-NLS-2$
    }
    if (B.getColumnSize() != R.getRowSize()) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.3")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final boolean qIsPositive = isPositive(Q);
    final boolean qIsSymmetric = isSymmetric(Q);

    if (qIsPositive == false || qIsSymmetric == false) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.4")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final boolean rIsPositive = isPositive(R);
    final boolean rIsSymmetric = isSymmetric(R);

    if (rIsPositive == false || rIsSymmetric == false) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.5")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final RM h = A.appendRight(B.divide(R).multiply(B.conjugateTranspose())).appendDown(Q.appendRight(A.conjugateTranspose().unaryMinus()));
    final EigenSolution<RS, RM, CS, CM> dv = h.eigenDecompose();
    final CM D = dv.getValue().diagonalToVector();
    final CM vector = dv.getVector();

    // Sort on real part of eigenvalues
    final IndexedMatrix<RS, RM> list5 = D.getRealPart().sort();
    final IntMatrix idx = list5.getIndices().transpose();
    final RM Dr = list5.getMatrix();

    final int size = A.getRowSize();

    if (!(Dr.getElement(size).isLessThan(0) && Dr.getElement(size + 1).isGreaterThan(0))) {
      throw new IllegalArgumentException("Lqr: " + Messages.getString("Lqr.6")); //$NON-NLS-1$ //$NON-NLS-2$
    }

    final CM u1 = vector.getSubMatrix(size + 1, 2 * size, idx.getSubVector(1, size));
    final CM v1 = vector.getSubMatrix(1, size, idx.getSubVector(1, size));
    final RM P = u1.divide(v1).getRealPart().unaryMinus();
    return P;
  }

  /**
   * 対称行列であるか判定します。
   * 
   * @param a 対象となる行列
   * @return 対称ならばtrue、そうでなければfalse
   */
  private static boolean isSymmetric(DoubleMatrix a) {
    final DoubleNumber frobNorm = a.frobNorm();
    final DoubleNumber tolerance = frobNorm.multiply(frobNorm.getMachineEpsilon());

    final DoubleNumber difference = a.transpose().subtract(a).norm(NormType.ONE);
    return difference.divide(a.norm(NormType.ONE)).isLessThan(tolerance);
  }

  /**
   * 対称行列であるか判定します。
   * 
   * @param <CS> 複素スカラーの型
   * @param <CM> 複素行列の型
   * @param <RS> スカラーの型
   * @param <RM> 行列の型
   * @param a 対象となる行列
   * @return 対称ならばtrue、そうでなければfalse
   */
  private static <RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> boolean isSymmetric(
      RM a) {
    final RS frobNorm = a.frobNorm();
    final RS tolerance = frobNorm.multiply(frobNorm.getMachineEpsilon());

    final RS difference = a.transpose().subtract(a).norm(NormType.ONE);
    return difference.divide(a.norm(NormType.ONE)).isLessThan(tolerance);
  }

  /**
   * 正定行列であるか判定します。
   * 
   * @param a 対象となる行列
   * @return 正定ならばtrue、そうでなければfalse
   */
  private static boolean isPositive(DoubleMatrix a) {
    final DoubleNumber frobNorm = a.frobNorm();
    final DoubleNumber tolerance = frobNorm.multiply(frobNorm.getMachineEpsilon());

    return a.eigenValue().getRealPart().compareElementWise(".<", tolerance.unaryMinus()).anyTrue() == false; //$NON-NLS-1$
  }

  /**
   * 正定行列であるか判定します。
   * 
   * @param <CS> 複素スカラーの型
   * @param <CM> 複素行列の型
   * @param <RS> スカラーの型
   * @param <RM> 行列の型
   * @param a 対象となる行列
   * @return 正定ならばtrue、そうでなければfalse
   */
  private static <RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> boolean isPositive(
      RM a) {
    final RS frobNorm = a.frobNorm();
    final RS tolerance = frobNorm.multiply(frobNorm.getMachineEpsilon());

    return a.eigenValue().getRealPart().compareElementWise(".<", tolerance.unaryMinus()).anyTrue() == false; //$NON-NLS-1$
  }

}