Rust에서 egg 라이브러리를 사용해 E-Graphs(동치 그래프)의 핵심 개념, 합동성, equality saturation, 추출과 비용 함수, 설명(Explanations) API까지 살펴본다.
E-Graphs(일명 equality graphs)는 프로그래밍 언어 공학에서 가장 흥미롭고 빠르게 발전하는 분야 중 하나입니다. 1980년 Gregory Nelson의 박사 학위 논문에서 처음 개발된 이후, 표현식에 대한 동치 관계를 유지하는 기반 데이터 구조가 되었습니다. 처음에는 SMT 솔버를 구동하기 위해 만들어졌지만, 최근 몇 년 사이 그 응용 범위가 크게 확장되었습니다.
이 e-graph 응용의 르네상스는 equality saturation 같은 기법에 의해 촉진되었는데, 이는 동치인 표현식을 더 효율적으로 탐색할 수 있게 해줍니다. E-Graphs의 핵심 혁신은 항 재작성(term rewriting)에서 이른바 “선택 문제(choice problem)”를 해결하는 능력에 있습니다. 전통적인 항 재작성은 파괴적입니다. 즉, 한 번 표현식을 변환하면 원래 표현식은 사라집니다. 이 때문에, 국소적으로는 최적이지만 전역적으로는 최적이 아닐 수 있는 특정 최적화 선택을 강제로 확정해야 합니다. 예를 들어 x + x는 x * 2로 최적화할 수 있지만, 이를 (x + x) - x에 적용하면 덧셈 하나와 뺄셈이 상쇄되어 단순히 x가 된다는 사실을 발견하지 못하게 됩니다.
E-Graphs는 가능한 모든 동치 표현식을 동시에, 그리고 컴팩트한 형태로 유지함으로써 이를 해결합니다. E-Graphs는 다음으로 구성됩니다:
E-Graphs의 강력함은 두 가지 핵심 성질에서 나옵니다:
합동성(Congruence): x가 y와 동치라면 f(x)는 반드시 f(y)와 동치여야 합니다. 이 성질은 동치가 발견될 때 자동으로 유지됩니다.
컴팩트한 표현(Compact Representation): E-Graphs는 비슷한 표현식 사이에서 구조를 공유함으로써, 지수적으로 많은 동치 표현식을 선형 공간에 표현할 수 있습니다.
_equality saturation_이라는 기법을 통해 E-Graphs를 최적화에 사용할 때 과정은 다음과 같이 동작합니다:
Rust에서 E-Graphs를 시작하기 위해, E-Graphs 작업을 위한 강력하고 인체공학적인 API를 제공하는 egg 라이브러리를 사용하겠습니다. 먼저 Cargo.toml에 egg 크레이트를 추가합니다:
[dependencies]
egg = "0.6"
E-Graphs의 근본적인 성질은 단순한 동치 관계(equivalence relation)만이 아니라 합동 관계(congruence relation)를 유지한다는 점입니다. 합동성이란, 두 표현식 x와 y가 동치라면 x를 포함하는 더 큰 표현식은 x를 y로 바꾼 동일한 표현식과 동치여야 함을 의미합니다. 더 공식적으로, x ≡ y라면 임의의 컨텍스트 f에 대해 f(x) ≡ f(y)입니다.
간단한 예제로 이를 확인해 봅시다:
use egg::*;
define_language! {
enum SimpleLanguage {
Num(i32),
"+" = Add([Id; 2]),
Symbol(Symbol),
}
}
fn congruence_example() {
let mut egraph = EGraph::<SimpleLanguage, ()>::default();
// Create expressions: (+ a x) and (+ a y)
let a = egraph.add(SimpleLanguage::Symbol("a".into()));
let x = egraph.add(SimpleLanguage::Symbol("x".into()));
let y = egraph.add(SimpleLanguage::Symbol("y".into()));
let expr1 = egraph.add(SimpleLanguage::Add([a, x])); // (+ a x)
let expr2 = egraph.add(SimpleLanguage::Add([a, y])); // (+ a y)
// Initially, these are in different e-classes
assert_ne!(egraph.find(expr1), egraph.find(expr2));
// When we declare x ≡ y...
egraph.union(x, y);
egraph.rebuild();
// ...congruence ensures (+ a x) ≡ (+ a y)
assert_eq!(egraph.find(expr1), egraph.find(expr2));
}
이 예제에서는:
(+ a x)와 (+ a y)로 시작합니다.union을 사용해 x와 y가 동치임을 선언합니다.(+ a x)와 (+ a y)를 담고 있는 e-class들을 병합합니다.이 합동성 성질은 동치가 발견되는 동안 e-graph에 의해 자동으로 유지됩니다. e-graph가 rebuild 연산을 수행할 때, 다음을 수행합니다:
이러한 합동성의 자동 유지가 e-graph를 항 재작성과 프로그램 최적화에 매우 강력하게 만드는 이유입니다. 큰 표현식에서의 치환을 통해 간접적으로 생기는 동치까지 포함하여, 모든 동치 표현식이 올바르게 식별되도록 보장합니다.
코드로 들어가기 전에, egg의 핵심 추상화를 이해해 봅시다:
Language Definition: egg의 define_language! 매크로로 표현식 언어를 정의해야 합니다. 각 variant는 표현식에서의 연산자 또는 리프 노드를 나타냅니다.
Pattern Matching: egg는 ?x 같은 변수를 사용해 부분 표현식을 매칭하는 간단한 패턴 언어를 사용합니다. 이 패턴이 재작성 규칙을 구동합니다.
Rewrite Rules: rewrite! 매크로로 정의되는, 동치인 표현식 사이의 양방향 변환 규칙입니다.
E-Class Analysis: 동치 클래스에 대한 추가 정보를 유지하기 위해, e-class에 사용자 정의 분석을 붙일 수 있습니다.
이 개념들을 간단한 표현식 언어로 정의해 봅시다:
use egg::*;
// Define our expression language
define_language! {
enum SimpleLanguage {
Num(i32),
"+" = Add([Id; 2]),
"*" = Mul([Id; 2]),
Symbol(Symbol),
}
}
// Define our rewrite rules
fn make_rules() -> Vec<Rewrite<SimpleLanguage, ()>> {
vec![
rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
rewrite!("add-0"; "(+ ?a 0)" => "?a"),
rewrite!("mul-0"; "(* ?a 0)" => "0"),
rewrite!("mul-1"; "(* ?a 1)" => "?a"),
]
}
이제 이 규칙들을 사용해 표현식을 단순화하는 간단한 함수를 만들어 봅시다:
/// Simplify an expression using egg
fn simplify(s: &str) -> String {
// Parse the expression
let expr: RecExpr<SimpleLanguage> = s.parse().unwrap();
// Create a Runner and apply our rules
let runner = Runner::default()
.with_expr(&expr)
.run(&make_rules());
// Extract the smallest equivalent expression
let extractor = Extractor::new(&runner.egraph, AstSize);
let (best_cost, best) = extractor.find_best(runner.roots[0]);
println!("Simplified {} to {} with cost {}", expr, best, best_cost);
best.to_string()
}
#[test]
fn simple_tests() {
assert_eq!(simplify("(* 0 42)"), "0");
assert_eq!(simplify("(+ 0 (* 1 foo))"), "foo");
}
E-Graphs를 다룰 때 가장 자주 쓰게 될 핵심 함수는 네 가지입니다:
pub fn union(&mut self, id1: Id, id2: Id) -> bool
union 함수는 id1과 id2를 포함하는 e-class를 병합합니다. e-class가 병합되면 true를, 이미 같은 e-class에 있었다면 false를 반환합니다.
pub fn find(&self, id: Id) -> Id
find 함수는 id를 포함하는 e-class의 대표(representative)를 반환합니다. 대표는 e-class에서 가장 작은 표현식입니다.
pub fn rebuild(&mut self) -> usize
rebuild 함수는 합동 폐쇄(congruence closure)의 단일 라운드를 수행합니다. 합동성 때문에 동치가 된 모든 e-node를 찾고, 이들이 들어 있는 e-class를 병합합니다. 병합된 e-class의 수를 반환합니다.
egg 라이브러리는 여러 고급 기능도 제공합니다:
#[derive(Default)]
struct ConstantFolding;
impl Analysis<Math> for ConstantFolding {
type Data = Option<i32>;
fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
if a == &b {
DidMerge(false, false)
} else {
*a = None;
DidMerge(true, false)
}
}
fn make(egraph: &EGraph<Math, Self>, enode: &Math) -> Self::Data {
match enode {
Math::Num(n) => Some(*n),
Math::Add([a, b]) => {
let a = egraph[*a].data;
let b = egraph[*b].data;
a.zip(b).map(|(a, b)| a + b)
}
// ... other operations ...
_ => None,
}
}
}
rw!("div-cancel"; "(/ ?a ?a)" => "1" if is_not_zero("?a"))
Runner API를 사용해 equality saturation을 수행하는 완전한 예제를 살펴봅시다:
use egg::{*, rewrite as rw};
// Define our rewrite rules
let rules: &[Rewrite<SymbolLang, ()>] = &[
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
rw!("add-0"; "(+ ?x 0)" => "?x"),
rw!("mul-0"; "(* ?x 0)" => "0"),
rw!("mul-1"; "(* ?x 1)" => "?x"),
];
// Parse the initial expression
let start = "(+ 0 (* 1 a))".parse().unwrap();
// Run equality saturation
let runner = Runner::default().with_expr(&start).run(rules);
// Extract the best expression using AstSize cost function
let extractor = Extractor::new(&runner.egraph, AstSize);
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
// The expression simplifies to just "a" with cost 1
assert_eq!(best_expr, "a".parse().unwrap());
assert_eq!(best_cost, 1);
여기서 숫자를 다루는 것처럼 보이지만, SymbolLang는 실제로 모든 것을 문자열로 저장한다는 점에 유의하세요. 이는 데모 목적일 뿐입니다.
E-Graph를 포화(saturation) 상태까지 실행한 뒤에는, 잠재적으로 매우 많은 동치 표현식 중에서 “최선의” 표현식을 추출해야 합니다. egg 라이브러리는 비용 함수에 따라 최적 표현식을 찾을 수 있는 Extractor를 제공합니다.
추출이 어떻게 동작하는지 보여주는 완전한 예제는 다음과 같습니다:
use egg::*;
define_language! {
enum SimpleLanguage {
Num(i32),
"+" = Add([Id; 2]),
"*" = Mul([Id; 2]),
}
}
fn extraction_example() {
let rules: &[Rewrite<SimpleLanguage, ()>] = &[
rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
rewrite!("add-0"; "(+ ?a 0)" => "?a"),
rewrite!("mul-0"; "(* ?a 0)" => "0"),
rewrite!("mul-1"; "(* ?a 1)" => "?a"),
];
// Start with (+ 0 (* 1 10))
let start = "(+ 0 (* 1 10))".parse().unwrap();
let runner = Runner::default().with_expr(&start).run(rules);
let (egraph, root) = (runner.egraph, runner.roots[0]);
// Extract the smallest equivalent expression
let mut extractor = Extractor::new(&egraph, AstSize);
let (best_cost, best) = extractor.find_best(root);
// The expression simplifies to just "10" with cost 1
assert_eq!(best_cost, 1);
assert_eq!(best, "10".parse().unwrap());
}
Extractor는 어떤 표현식을 선택할지 결정하기 위해 비용 함수를 사용합니다. 이 예제에서는 표현식의 노드 개수를 세는 AstSize를 사용합니다. 표현식 (+ 0 (* 1 10))은 단지 10으로 단순화되며, 이는 비용이 1(숫자 노드 하나)입니다.
CostFunction 트레이트를 구현하여 사용자 정의 비용 함수를 정의할 수도 있습니다:
#[derive(Clone)]
struct CustomCost;
impl CostFunction<SimpleLanguage> for CustomCost {
type Cost = usize;
fn cost<C>(&mut self, enode: &SimpleLanguage, mut costs: C) -> Self::Cost
where C: FnMut(Id) -> Self::Cost
{
match enode {
SimpleLanguage::Num(_) => 1,
SimpleLanguage::Add([a, b]) => costs(*a) + costs(*b) + 2, // penalize additions more
SimpleLanguage::Mul([a, b]) => costs(*a) + costs(*b) + 1,
}
}
}
이 사용자 정의 비용 함수는 곱셈보다 덧셈에 더 큰 페널티를 부여합니다. 이는 대상 플랫폼에서 곱셈이 더 저렴할 때 유용할 수 있습니다.
egg의 강력한 기능 중 하나는 e-graph에서 두 항이 왜 동치인지 설명할 수 있다는 점입니다. 이는 재작성 규칙을 디버깅하거나 변환을 검증할 때 특히 유용합니다. Explanations API는 한 표현식이 다른 표현식으로 변환되는 정확한 재작성(rewrite) 시퀀스를 보여줄 수 있습니다.
설명을 활성화하려면 runner를 만들 때 with_explanations_enabled() 메서드를 사용합니다:
use egg::*;
fn explanation_example() {
let rules: &[Rewrite<SymbolLang, ()>] = &[
rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
rewrite!("add-0"; "(+ ?a 0)" => "?a"),
];
let start = "(+ (+ a 0) b)".parse().unwrap();
let end = "(+ b a)".parse().unwrap();
let mut runner = Runner::default()
.with_explanations_enabled()
.with_expr(&start)
.run(rules);
// Get explanation of how start transforms to end
println!("{}", runner.explain_equivalence(&start, &end).get_flat_string());
}
설명 출력은 어떤 재작성 규칙이 적용되었는지를 나타내는 주석과 함께 변환의 각 단계를 보여줍니다:
(+ (+ a 0) b)
(+ (Rewrite=> add-0 a) b)
(Rewrite=> commute-add (+ b a))
설명은 두 가지 형태로 제공됩니다:
트리 형식은 공유되는 부분 항이 많은 큰 표현식을 다룰 때 특히 유용합니다. 다음은 트리 설명의 예입니다:
let tree_explanation = runner.explain_equivalence(&start, &end);
println!("{}", tree_explanation.get_tree_string());
또한 explain_existence()를 사용해 특정 항이 e-graph에 존재하는 이유를 이해할 수 있습니다:
// Find out how a specific term came to exist
let term = "(+ a b)".parse().unwrap();
if let Some(explanation) = runner.explain_existence(&term) {
println!("Term exists because: {}", explanation.get_flat_string());
}
이는 예상치 못한 동치를 디버깅하거나, 최적화가 의도대로 동작하는지 검증할 때 특히 가치가 있습니다.