Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discrete Elimination Refactor #1919

Merged
merged 66 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
d1d440a
add nrValues method
varunagrawal Dec 7, 2024
a68da21
operator* version which accepts DiscreteFactor
varunagrawal Dec 7, 2024
a09b77e
return DiscreteFactor shared_ptr as leftover from elimination
varunagrawal Dec 7, 2024
27bbce1
generalize DiscreteFactorGraph::product to DiscreteFactor
varunagrawal Dec 7, 2024
84e4194
make normalization code common
varunagrawal Dec 7, 2024
4dac37c
make sum and max DiscreteFactor methods
varunagrawal Dec 7, 2024
6c45467
add timing info
varunagrawal Dec 7, 2024
b0ad350
add note about toDecisionTreeFactor
varunagrawal Dec 7, 2024
306a3ba
kill toDecisionTreeFactor to force rethink
varunagrawal Dec 7, 2024
2cd2ab0
DiscreteDistribution from TableFactor
varunagrawal Dec 7, 2024
9f88a36
make evaluate use the Assignment<Key> base class
varunagrawal Dec 7, 2024
2a3b5e6
use Assignment<Key> for evaluate since it is the base class
varunagrawal Dec 7, 2024
fff8458
remove TableFactor constructor in DiscreteDistribution
varunagrawal Dec 8, 2024
295b965
use Assignment<Key> since it is a base class
varunagrawal Dec 8, 2024
261038f
fix DiscreteConditional constructor
varunagrawal Dec 8, 2024
20d6d09
use DiscreteFactor everywhere in DiscreteFactorGraph.cpp
varunagrawal Dec 8, 2024
32b6bc0
update DiscreteConditional
varunagrawal Dec 8, 2024
38563da
Revert "kill toDecisionTreeFactor to force rethink"
varunagrawal Dec 8, 2024
9633ad1
make DiscreteConditional::likelihood match the declaration
varunagrawal Dec 8, 2024
0b3477f
get different classes to play nicely
varunagrawal Dec 8, 2024
1d79188
compiles
varunagrawal Dec 8, 2024
7757851
timing
varunagrawal Dec 8, 2024
9844a55
move evaluate and operator() next to each other
varunagrawal Dec 8, 2024
aa25ccf
implement evaluate in DiscreteFactor
varunagrawal Dec 8, 2024
90d7e21
change from DiscreteValues to Assignment<Key>
varunagrawal Dec 8, 2024
6665659
use BaseFactor instead of DecisionTreeFactor
varunagrawal Dec 8, 2024
f9a9801
Merge branch 'ring' into discrete-elimination-refactor
varunagrawal Dec 8, 2024
e6b6528
common definitions of Unary, UnaryAssignment and Binary
varunagrawal Dec 8, 2024
f85284a
some cleanup based on previous commit
varunagrawal Dec 8, 2024
5e86f7e
remove previously added code
varunagrawal Dec 8, 2024
1c14a56
revert changes to make code generic
varunagrawal Dec 8, 2024
b325150
revert DiscreteFactorGraph::product
varunagrawal Dec 8, 2024
0afc198
revert some DiscreteFactorGraph changes
varunagrawal Dec 8, 2024
975fe62
add methods in gtsam_unstable
varunagrawal Dec 8, 2024
fc2d33f
add division with DiscreteFactor::shared_ptr for convenience
varunagrawal Dec 8, 2024
2c02efc
fix tests
varunagrawal Dec 8, 2024
360598d
undo uncomment
varunagrawal Dec 8, 2024
853241c
add evaluate to DiscreteConditional
varunagrawal Dec 8, 2024
199c0a0
keep using DecisionTreeFactor for DiscreteConditional
varunagrawal Dec 8, 2024
214bf4e
more fixes
varunagrawal Dec 8, 2024
0de114f
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 9, 2024
e46cd54
TableFactor cleanup
varunagrawal Dec 9, 2024
52c8034
add division by DiscreteFactor in TableFactor
varunagrawal Dec 9, 2024
e0e833c
cleanup
varunagrawal Dec 9, 2024
84627c0
fix error
varunagrawal Dec 9, 2024
cc4e9cb
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 10, 2024
0b3f058
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 11, 2024
22d11d7
don't print timing info by default
varunagrawal Dec 11, 2024
90d8486
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 30, 2024
d3901be
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Jan 5, 2025
5fa04d7
small improvements
varunagrawal Jan 5, 2025
268290d
multiply method for DiscreteFactor
varunagrawal Jan 5, 2025
2f09e86
remove override from definition
varunagrawal Jan 5, 2025
5e9c130
Merge branch 'discrete-multiply' into discrete-elimination-refactor
varunagrawal Jan 5, 2025
b5128b2
use DecisionTreeFactor version of sum and max where not available
varunagrawal Jan 6, 2025
4ebca71
divide operator for DiscreteFactor::shared_ptr
varunagrawal Jan 6, 2025
fb1d52a
fix constructor
varunagrawal Jan 6, 2025
e9822a7
update DiscreteFactorGraph to use DiscreteFactor::shared_ptr for elim…
varunagrawal Jan 6, 2025
2f8c8dd
update tests
varunagrawal Jan 6, 2025
2434e24
undo print change in DiscreteLookupTable
varunagrawal Jan 6, 2025
ab2fe37
Merge branch 'discrete-multiply' into discrete-elimination-refactor
varunagrawal Jan 6, 2025
7561da4
move operator/ to Constraint.h
varunagrawal Jan 6, 2025
ff5371f
move sum, max and nrValues to Constraint class as well
varunagrawal Jan 6, 2025
f932945
check pointer casts
varunagrawal Jan 6, 2025
f8dedb5
use DiscreteFactor for DiscreteConditional constructor
varunagrawal Jan 6, 2025
c754f9b
add comments
varunagrawal Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ namespace gtsam {
return result;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments

// Check if `f` is a TableFactor. If yes, then
// convert `this` to a TableFactor which is cheaper.
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));

} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, divide normally.
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));

} else {
// Else, convert `f` to a DecisionTreeFactor so we can divide
return std::make_shared<DecisionTreeFactor>(
this->operator/(f->toDecisionTreeFactor()));
}
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
18 changes: 14 additions & 4 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,26 +164,30 @@ namespace gtsam {
return apply(f, safe_div);
}

/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;

/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}

/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}

Expand Down Expand Up @@ -264,6 +268,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }

/// @}
/// @name Wrapper support
/// @{
Expand Down
19 changes: 10 additions & 9 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cassert>

using namespace std;
using std::pair;
Expand All @@ -44,8 +44,9 @@ template class GTSAM_EXPORT

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this now be a DiscreetFactor& ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it can!

const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
const DiscreteFactor& f)
: BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
BaseConditional(nrFrontals) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
Expand Down Expand Up @@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
if (!dynamic_cast<const BaseFactor*>(&other)) {
return false;
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
const BaseFactor& f(static_cast<const BaseFactor&>(other));
return BaseFactor::equals(f, tol);
}
}

Expand Down Expand Up @@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names);
ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}

Expand Down Expand Up @@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "</i></p>\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names);
ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}

Expand Down Expand Up @@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,

/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete());
return this->operator()(x.discrete());
}

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {}

/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);

/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
Expand Down
23 changes: 23 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/inference/Ordering.h>

#include <string>
namespace gtsam {
Expand Down Expand Up @@ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;

/// divide by DiscreteFactor::shared_ptr f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;

/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;

/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;

/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
virtual uint64_t nrValues() const = 0;

/// @}
/// @name Wrapper support
/// @{
Expand Down
35 changes: 18 additions & 17 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace gtsam {
}

/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
Expand All @@ -76,7 +76,7 @@ namespace gtsam {
}
}
}
return result->toDecisionTreeFactor();
return result;
}

/* ************************************************************************ */
Expand Down Expand Up @@ -122,20 +122,20 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
* @return DiscreteFactor::shared_ptr
*/
static DecisionTreeFactor DiscreteProduct(
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
auto denominator = product->max(product->size());

// Normalize the product factor to prevent underflow.
product = product / (*denominator);
product = product->operator/(denominator);

return product;
}
Expand All @@ -145,25 +145,25 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);

// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);

// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));

// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
auto lookup = std::make_shared<DiscreteLookupTable>(
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);

return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
Expand Down Expand Up @@ -223,11 +223,11 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);

// sum out frontals, this is the factor on the separator
gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);

// Ordering keys for the conditional so that frontalKeys are really in front
Expand All @@ -239,8 +239,9 @@ namespace gtsam {

// now divide product/sum to get conditional
gttic(divide);
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
auto conditional = std::make_shared<DiscreteConditional>(
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
orderedKeys);
gttoc(divide);

return {conditional, sum};
Expand Down
3 changes: 2 additions & 1 deletion gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph

/// @}

//TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template <typename... Args>
void add(Args&&... args) {
Expand All @@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;

/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
DiscreteFactor::shared_ptr product() const;

/**
* Evaluates the factor graph given values, returns the joint probability of
Expand Down
13 changes: 13 additions & 0 deletions gtsam/discrete/DiscreteLookupDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

Expand Down Expand Up @@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}

/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
potentials.toDecisionTreeFactor()) {}

/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
Expand Down
14 changes: 14 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
return result;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator/(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<TableFactor>(
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
} else {
TableFactor divisor(f->toDecisionTreeFactor());
return std::make_shared<TableFactor>(this->operator/(divisor));
}
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
Expand Down
Loading
Loading