Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discrete Elimination Refactor #1919

Merged
merged 66 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
d1d440a
add nrValues method
varunagrawal Dec 7, 2024
a68da21
operator* version which accepts DiscreteFactor
varunagrawal Dec 7, 2024
a09b77e
return DiscreteFactor shared_ptr as leftover from elimination
varunagrawal Dec 7, 2024
27bbce1
generalize DiscreteFactorGraph::product to DiscreteFactor
varunagrawal Dec 7, 2024
84e4194
make normalization code common
varunagrawal Dec 7, 2024
4dac37c
make sum and max DiscreteFactor methods
varunagrawal Dec 7, 2024
6c45467
add timing info
varunagrawal Dec 7, 2024
b0ad350
add note about toDecisionTreeFactor
varunagrawal Dec 7, 2024
306a3ba
kill toDecisionTreeFactor to force rethink
varunagrawal Dec 7, 2024
2cd2ab0
DiscreteDistribution from TableFactor
varunagrawal Dec 7, 2024
9f88a36
make evaluate use the Assignment<Key> base class
varunagrawal Dec 7, 2024
2a3b5e6
use Assignment<Key> for evaluate since it is the base class
varunagrawal Dec 7, 2024
fff8458
remove TableFactor constructor in DiscreteDistribution
varunagrawal Dec 8, 2024
295b965
use Assignment<Key> since it is a base class
varunagrawal Dec 8, 2024
261038f
fix DiscreteConditional constructor
varunagrawal Dec 8, 2024
20d6d09
use DiscreteFactor everywhere in DiscreteFactorGraph.cpp
varunagrawal Dec 8, 2024
32b6bc0
update DiscreteConditional
varunagrawal Dec 8, 2024
38563da
Revert "kill toDecisionTreeFactor to force rethink"
varunagrawal Dec 8, 2024
9633ad1
make DiscreteConditional::likelihood match the declaration
varunagrawal Dec 8, 2024
0b3477f
get different classes to play nicely
varunagrawal Dec 8, 2024
1d79188
compiles
varunagrawal Dec 8, 2024
7757851
timing
varunagrawal Dec 8, 2024
9844a55
move evaluate and operator() next to each other
varunagrawal Dec 8, 2024
aa25ccf
implement evaluate in DiscreteFactor
varunagrawal Dec 8, 2024
90d7e21
change from DiscreteValues to Assignment<Key>
varunagrawal Dec 8, 2024
6665659
use BaseFactor instead of DecisionTreeFactor
varunagrawal Dec 8, 2024
f9a9801
Merge branch 'ring' into discrete-elimination-refactor
varunagrawal Dec 8, 2024
e6b6528
common definitions of Unary, UnaryAssignment and Binary
varunagrawal Dec 8, 2024
f85284a
some cleanup based on previous commit
varunagrawal Dec 8, 2024
5e86f7e
remove previously added code
varunagrawal Dec 8, 2024
1c14a56
revert changes to make code generic
varunagrawal Dec 8, 2024
b325150
revert DiscreteFactorGraph::product
varunagrawal Dec 8, 2024
0afc198
revert some DiscreteFactorGraph changes
varunagrawal Dec 8, 2024
975fe62
add methods in gtsam_unstable
varunagrawal Dec 8, 2024
fc2d33f
add division with DiscreteFactor::shared_ptr for convenience
varunagrawal Dec 8, 2024
2c02efc
fix tests
varunagrawal Dec 8, 2024
360598d
undo uncomment
varunagrawal Dec 8, 2024
853241c
add evaluate to DiscreteConditional
varunagrawal Dec 8, 2024
199c0a0
keep using DecisionTreeFactor for DiscreteConditional
varunagrawal Dec 8, 2024
214bf4e
more fixes
varunagrawal Dec 8, 2024
0de114f
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 9, 2024
e46cd54
TableFactor cleanup
varunagrawal Dec 9, 2024
52c8034
add division by DiscreteFactor in TableFactor
varunagrawal Dec 9, 2024
e0e833c
cleanup
varunagrawal Dec 9, 2024
84627c0
fix error
varunagrawal Dec 9, 2024
cc4e9cb
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 10, 2024
0b3f058
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 11, 2024
22d11d7
don't print timing info by default
varunagrawal Dec 11, 2024
90d8486
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 30, 2024
d3901be
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Jan 5, 2025
5fa04d7
small improvements
varunagrawal Jan 5, 2025
268290d
multiply method for DiscreteFactor
varunagrawal Jan 5, 2025
2f09e86
remove override from definition
varunagrawal Jan 5, 2025
5e9c130
Merge branch 'discrete-multiply' into discrete-elimination-refactor
varunagrawal Jan 5, 2025
b5128b2
use DecisionTreeFactor version of sum and max where not available
varunagrawal Jan 6, 2025
4ebca71
divide operator for DiscreteFactor::shared_ptr
varunagrawal Jan 6, 2025
fb1d52a
fix constructor
varunagrawal Jan 6, 2025
e9822a7
update DiscreteFactorGraph to use DiscreteFactor::shared_ptr for elim…
varunagrawal Jan 6, 2025
2f8c8dd
update tests
varunagrawal Jan 6, 2025
2434e24
undo print change in DiscreteLookupTable
varunagrawal Jan 6, 2025
ab2fe37
Merge branch 'discrete-multiply' into discrete-elimination-refactor
varunagrawal Jan 6, 2025
7561da4
move operator/ to Constraint.h
varunagrawal Jan 6, 2025
ff5371f
move sum, max and nrValues to Constraint class as well
varunagrawal Jan 6, 2025
f932945
check pointer casts
varunagrawal Jan 6, 2025
f8dedb5
use DiscreteFactor for DiscreteConditional constructor
varunagrawal Jan 6, 2025
c754f9b
add comments
varunagrawal Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,26 @@ namespace gtsam {
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
DiscreteFactor::shared_ptr DecisionTreeFactor::operator*(
const DiscreteFactor::shared_ptr& f) const {
if (auto derived = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<DecisionTreeFactor>(this->operator*(*derived));
} else {
throw std::runtime_error(
"Cannot convert DiscreteFactor to DecisionTreeFactor");
}
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
Expand All @@ -100,7 +111,7 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
Binary op) const {
map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map
for (Key j : keys()) cs[j] = cardinality(j);
Expand All @@ -118,8 +129,8 @@ namespace gtsam {
}

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
Binary op) const {
if (nrFrontals > size()) {
throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal "
Expand All @@ -146,7 +157,7 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const {
const Ordering& frontalKeys, Binary op) const {
if (frontalKeys.size() > size()) {
throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal "
Expand Down
54 changes: 37 additions & 17 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;

// Needed since we have definitions in both DiscreteFactor and DecisionTree
using Base::Binary;
using Base::Unary;
using Base::UnaryAssignment;

/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -130,52 +135,61 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/// Calculate probability for given values `x`,
/// Calculate probability for given values,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const Assignment<Key>& values) const {
return ADT::operator()(values);
}

/// Evaluate probability distribution, sugar.
double operator()(const DiscreteValues& values) const override {
double operator()(const Assignment<Key>& values) const override {
return ADT::operator()(values);
}

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const {
return apply(f, Ring::mul);
}

DiscreteFactor::shared_ptr operator*(
const DiscreteFactor::shared_ptr& f) const override;

static double safe_div(const double& a, const double& b);

/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
}

/// divide by factor f (pointer version)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override {
if (auto derived = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<DecisionTreeFactor>(apply(*derived, safe_div));
} else {
throw std::runtime_error(
"Cannot convert DiscreteFactor to Table Factor");
}
}

/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}

/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}

Expand All @@ -187,37 +201,37 @@ namespace gtsam {
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::Unary op) const;
DecisionTreeFactor apply(Unary op) const;

/**
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree. Takes
* both the assignment and the value.
*/
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
DecisionTreeFactor apply(UnaryAssignment op) const;

/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const;

/**
* Combine frontal variables using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
shared_ptr combine(size_t nrFrontals, Binary op) const;

/**
* Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
shared_ptr combine(const Ordering& keys, Binary op) const;

/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
Expand Down Expand Up @@ -256,6 +270,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }

/// @}
/// @name Wrapper support
/// @{
Expand Down
29 changes: 16 additions & 13 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ using std::vector;
namespace gtsam {

// Instantiate base class
template class GTSAM_EXPORT
Conditional<DecisionTreeFactor, DiscreteConditional>;
template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional>;

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this now be a DiscreetFactor& ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it can!

const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
: BaseFactor(f / (*std::dynamic_pointer_cast<DecisionTreeFactor>(
f.sum(nrFrontals)))),
BaseConditional(nrFrontals) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
Expand All @@ -53,15 +54,17 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal)
: BaseFactor(joint / marginal),
BaseConditional(joint.size() - marginal.size()) {}
DiscreteConditional::DiscreteConditional(
const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal)
: BaseFactor(*std::dynamic_pointer_cast<DecisionTreeFactor>(
joint->operator/(marginal))),
BaseConditional(joint->size() - marginal->size()) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
const Ordering& orderedKeys)
DiscreteConditional::DiscreteConditional(
const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
Expand Down Expand Up @@ -199,7 +202,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::choose(
}

/* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
Expand All @@ -224,7 +227,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
}

/* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
size_t frontal) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
Expand Down Expand Up @@ -474,7 +477,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,

/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete());
return this->operator()(x.discrete());
}

/* ************************************************************************* */
Expand Down
21 changes: 8 additions & 13 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal);

/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal,
const Ordering& orderedKeys);

/**
Expand Down Expand Up @@ -168,13 +168,8 @@ class GTSAM_EXPORT DiscreteConditional
static_cast<const BaseConditional*>(this)->print(s, formatter);
}

/// Evaluate, just look up in AlgebraicDecisionTree
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}

using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version
using BaseFactor::error; ///< DiscreteValues version
using BaseFactor::operator(); ///< DiscreteValues version

/**
* @brief restrict to given *parent* values.
Expand All @@ -192,11 +187,11 @@ class GTSAM_EXPORT DiscreteConditional
shared_ptr choose(const DiscreteValues& given) const;

/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
DiscreteFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;

/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
DiscreteFactor::shared_ptr likelihood(size_t frontal) const;

/**
* sample
Expand Down
5 changes: 2 additions & 3 deletions gtsam/discrete/DiscreteDistribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Default constructor needed for serialization.
DiscreteDistribution() {}

/// Constructor from factor.
/// Constructor from DecisionTreeFactor.
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}

Expand Down Expand Up @@ -86,8 +86,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
double operator()(size_t value) const;

/// We also want to keep the Base version, taking DiscreteValues:
// TODO(dellaert): does not play well with wrapper!
// using Base::operator();
using Base::operator();

/// Return entire probability mass function.
std::vector<double> pmf() const;
Expand Down
Loading
Loading