Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DiscreteFactor::errorTree method for all assignments #1669

Merged
merged 3 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);

// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ namespace gtsam {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;

/// @}

private:
Expand Down
15 changes: 10 additions & 5 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/base/Testable.h>

#include <string>
namespace gtsam {
Expand All @@ -35,7 +36,7 @@ class HybridValues;
*
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {
class GTSAM_EXPORT DiscreteFactor : public Factor {
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
Expand Down Expand Up @@ -103,15 +104,19 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
*/
double error(const HybridValues& c) const override;

/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;

/// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
/// @name Wrapper support
/// @{

/// Translation table from values to strings.
using Names = DiscreteValues::Names;

Expand Down Expand Up @@ -175,4 +180,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
std::vector<double> expNormalize(const std::vector<double> &logProbs);


}// namespace gtsam
} // namespace gtsam
5 changes: 5 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().errorTree();
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;

/// @}
};

Expand Down
18 changes: 18 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);

// Create factors
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");

auto errors = f.errorTree();
// regression
AlgebraicDecisionTree<Key> expected(
{X, Y, Z},
vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
-1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
-4.1743873, -3.8066625, -4.3174881});
EXPECT(assert_equal(expected, errors, 1e-6));
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,14 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class GTSAM_EXPORT GaussianMixture
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
AlgebraicDecisionTree<Key> GaussianMixtureFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

Expand All @@ -431,7 +431,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(

if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + gaussianMixture->error(continuousValues);
error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
Expand Down Expand Up @@ -460,7 +460,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const;

/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/MixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factor, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const Values& continuousValues) const {
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousValues](const sharedFactor& factor) {
return factor->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
DecisionTree<Key, double> result(factors_, errorFunc);
return result;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) {
/// Check error.
TEST(GaussianMixture, Error) {
using namespace equal_constants;
auto actual = mixture.error(vv);
auto actual = mixture.errorTree(vv);

// Check result.
std::vector<DiscreteKey> discrete_keys = {mode};
Expand Down Expand Up @@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) {
std::vector<double> leaves = {conditionals[0]->likelihood(vv)->error(vv),
conditionals[1]->likelihood(vv)->error(vv)};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6));
EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6));

// Check that the ratio of probPrime to evaluate is the same for all modes.
std::vector<double> ratio(2);
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testGaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
continuousValues.insert(X(2), Vector2(1, 1));

// error should return a tree of errors, with nodes for each discrete value.
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.errorTree(continuousValues);

std::vector<DiscreteKey> discrete_keys = {m1};
// Error values for regression test
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous());
auto error_tree = graph.errorTree(delta.continuous());

std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/tests/testMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ TEST(MixtureFactor, Error) {
continuousValues.insert<double>(X(1), 0);
continuousValues.insert<double>(X(2), 1);

AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
AlgebraicDecisionTree<Key> error_tree =
mixtureFactor.errorTree(continuousValues);

DiscreteKey m1(1, 2);
std::vector<DiscreteKey> discrete_keys = {m1};
Expand Down
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
}

/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
Expand Down
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/BinaryAllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ class BinaryAllDiff : public Constraint {
const Domains&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("BinaryAllDiff::error not implemented");
}
};

} // namespace gtsam
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/Domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
}
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("Domain::error not implemented");
}

// Return concise string representation, mostly to debug arc consistency.
// Converts from base 0 to base1.
std::string base1Str() const;
Expand Down
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/SingleValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
}
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("SingleValue::error not implemented");
}

/// Calculate value
double operator()(const DiscreteValues& values) const override;

Expand Down