Skip to content

Commit

Permalink
Merge pull request #1919 from borglab/discrete-elimination-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 7, 2025
2 parents 47074bd + c754f9b commit 82d0ebc
Show file tree
Hide file tree
Showing 18 changed files with 188 additions and 60 deletions.
19 changes: 19 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,25 @@ namespace gtsam {
return result;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// Check if `f` is a TableFactor. If yes, then
// convert `this` to a TableFactor which is cheaper.
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));

} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, divide normally.
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));

} else {
// Else, convert `f` to a DecisionTreeFactor so we can divide
return std::make_shared<DecisionTreeFactor>(
this->operator/(f->toDecisionTreeFactor()));
}
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
18 changes: 14 additions & 4 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,26 +184,30 @@ namespace gtsam {
return apply(f, safe_div);
}

/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;

/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}

/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}

Expand Down Expand Up @@ -284,6 +288,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }

/// @}
/// @name Wrapper support
/// @{
Expand Down
19 changes: 10 additions & 9 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cassert>

using namespace std;
using std::pair;
Expand All @@ -44,8 +44,9 @@ template class GTSAM_EXPORT

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
const DiscreteFactor& f)
: BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
BaseConditional(nrFrontals) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
Expand Down Expand Up @@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
if (!dynamic_cast<const BaseFactor*>(&other)) {
return false;
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
const BaseFactor& f(static_cast<const BaseFactor&>(other));
return BaseFactor::equals(f, tol);
}
}

Expand Down Expand Up @@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names);
ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}

Expand Down Expand Up @@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "</i></p>\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names);
ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}

Expand Down Expand Up @@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,

/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete());
return this->operator()(x.discrete());
}

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {}

/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);

/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
Expand Down
23 changes: 23 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/inference/Ordering.h>

#include <string>
namespace gtsam {
Expand Down Expand Up @@ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;

/// divide by DiscreteFactor::shared_ptr f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;

/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;

/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;

/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
virtual uint64_t nrValues() const = 0;

/// @}
/// @name Wrapper support
/// @{
Expand Down
35 changes: 18 additions & 17 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace gtsam {
}

/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
Expand All @@ -76,7 +76,7 @@ namespace gtsam {
}
}
}
return result->toDecisionTreeFactor();
return result;
}

/* ************************************************************************ */
Expand Down Expand Up @@ -122,20 +122,20 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
* @return DiscreteFactor::shared_ptr
*/
static DecisionTreeFactor DiscreteProduct(
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
auto denominator = product->max(product->size());

// Normalize the product factor to prevent underflow.
product = product / (*denominator);
product = product->operator/(denominator);

return product;
}
Expand All @@ -145,25 +145,25 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);

// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);

// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));

// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
auto lookup = std::make_shared<DiscreteLookupTable>(
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);

return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
Expand Down Expand Up @@ -223,11 +223,11 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);

// sum out frontals, this is the factor on the separator
gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);

// Ordering keys for the conditional so that frontalKeys are really in front
Expand All @@ -239,8 +239,9 @@ namespace gtsam {

// now divide product/sum to get conditional
gttic(divide);
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
auto conditional = std::make_shared<DiscreteConditional>(
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
orderedKeys);
gttoc(divide);

return {conditional, sum};
Expand Down
3 changes: 2 additions & 1 deletion gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph

/// @}

//TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template <typename... Args>
void add(Args&&... args) {
Expand All @@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;

/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
DiscreteFactor::shared_ptr product() const;

/**
* Evaluates the factor graph given values, returns the joint probability of
Expand Down
13 changes: 13 additions & 0 deletions gtsam/discrete/DiscreteLookupDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

Expand Down Expand Up @@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}

/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
potentials.toDecisionTreeFactor()) {}

/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
Expand Down
14 changes: 14 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
return result;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator/(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<TableFactor>(
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
} else {
TableFactor divisor(f->toDecisionTreeFactor());
return std::make_shared<TableFactor>(this->operator/(divisor));
}
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
Expand Down
Loading

0 comments on commit 82d0ebc

Please sign in to comment.