Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit ab31965

Browse files
committed
feat: add einsum expression parser
1 parent 8c1f8da commit ab31965

File tree

2 files changed

+292
-0
lines changed

2 files changed

+292
-0
lines changed

wonnx/src/einsum.rs

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
use std::{
2+
collections::{BTreeMap, BTreeSet},
3+
fmt::Display,
4+
};
5+
6+
use thiserror::Error;
7+
8+
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
9+
pub struct Subscript(char);
10+
11+
#[derive(Debug, Clone)]
12+
pub enum Subscripts {
13+
Indexes(Vec<Subscript>),
14+
Ellipsis {
15+
start: Vec<Subscript>,
16+
end: Vec<Subscript>,
17+
},
18+
}
19+
20+
/// Represents an Einstein summation expression following the notation described [https://onnx.ai/onnx/operators/onnx__Einsum.html](here).
21+
#[derive(Debug, Clone)]
22+
pub struct Einsum {
23+
inputs: Vec<Subscripts>,
24+
output: Option<Subscripts>,
25+
}
26+
27+
#[derive(Error, Debug)]
28+
pub enum EinsumError {
29+
#[error("invalid character encountered: {0}")]
30+
InvalidCharacter(char),
31+
32+
#[error("the formula has no inputs")]
33+
MissingInputs,
34+
}
35+
36+
impl Subscript {
37+
pub fn from(c: char) -> Subscript {
38+
assert!(c.is_alphabetic());
39+
Subscript(c)
40+
}
41+
}
42+
43+
fn count_indices(inputs: &[Subscripts]) -> BTreeMap<Subscript, u32> {
44+
let mut count = BTreeMap::new();
45+
for input in inputs {
46+
for c in input.subscripts() {
47+
count.entry(c).and_modify(|n| *n += 1).or_insert(1);
48+
}
49+
}
50+
count
51+
}
52+
53+
impl Subscripts {
54+
fn push(&mut self, index: Subscript) {
55+
match self {
56+
Subscripts::Indexes(idxs) => idxs.push(index),
57+
Subscripts::Ellipsis { end, .. } => {
58+
end.push(index);
59+
}
60+
}
61+
}
62+
63+
fn is_empty(&self) -> bool {
64+
match self {
65+
Subscripts::Indexes(idx) => idx.is_empty(),
66+
Subscripts::Ellipsis { start, end } => start.is_empty() && end.is_empty(),
67+
}
68+
}
69+
70+
fn subscripts(&self) -> Vec<Subscript> {
71+
match &self {
72+
Subscripts::Indexes(indices) => indices.clone(),
73+
Subscripts::Ellipsis { start, end } => {
74+
start.iter().chain(end.iter()).cloned().collect()
75+
}
76+
}
77+
}
78+
}
79+
80+
impl Display for Subscript {
81+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82+
write!(f, "{}", self.0)
83+
}
84+
}
85+
86+
impl Display for Subscripts {
87+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88+
match self {
89+
Subscripts::Indexes(idxs) => {
90+
for i in idxs {
91+
write!(f, "{}", i)?;
92+
}
93+
Ok(())
94+
}
95+
Subscripts::Ellipsis { start, end } => {
96+
for i in start {
97+
write!(f, "{}", i)?;
98+
}
99+
write!(f, "...")?;
100+
for i in end {
101+
write!(f, "{}", i)?;
102+
}
103+
Ok(())
104+
}
105+
}
106+
}
107+
}
108+
109+
impl Einsum {
110+
#[allow(dead_code)]
111+
pub fn from(str: &str) -> Result<Einsum, EinsumError> {
112+
let mut sum = Einsum {
113+
inputs: vec![],
114+
output: None,
115+
};
116+
117+
// Parse up to arrow
118+
let mut chars = str.chars();
119+
let mut current_subscripts = Subscripts::Indexes(vec![]);
120+
let mut after_arrow = false;
121+
while let Some(character) = &chars.next() {
122+
match character {
123+
'-' if chars.next() == Some('>') => {
124+
// Arrow: switch from inputs to outputs
125+
if !current_subscripts.is_empty() {
126+
sum.inputs.push(current_subscripts);
127+
current_subscripts = Subscripts::Indexes(vec![]);
128+
}
129+
if sum.inputs.is_empty() {
130+
return Err(EinsumError::MissingInputs);
131+
}
132+
after_arrow = true;
133+
}
134+
'.' if chars.next() == Some('.') && chars.next() == Some('.') => {
135+
// Ellipsis
136+
current_subscripts = match current_subscripts {
137+
Subscripts::Indexes(idxs) => Subscripts::Ellipsis {
138+
start: idxs,
139+
end: vec![],
140+
},
141+
Subscripts::Ellipsis { .. } => {
142+
return Err(EinsumError::InvalidCharacter('.'))
143+
}
144+
}
145+
}
146+
' ' => {}
147+
',' if !after_arrow => {
148+
// Next input (cannot occur in output)
149+
sum.inputs.push(current_subscripts);
150+
current_subscripts = Subscripts::Indexes(vec![]);
151+
}
152+
c if c.is_alphabetic() => {
153+
current_subscripts.push(Subscript::from(*c));
154+
}
155+
_ => return Err(EinsumError::InvalidCharacter(*character)),
156+
}
157+
}
158+
159+
// If we still have subscripts, they are either the last input or the output
160+
if !current_subscripts.is_empty()
161+
|| matches!(current_subscripts, Subscripts::Ellipsis { .. }) && after_arrow
162+
{
163+
if after_arrow {
164+
sum.output = Some(current_subscripts);
165+
} else {
166+
sum.inputs.push(current_subscripts);
167+
}
168+
}
169+
170+
Ok(sum)
171+
}
172+
173+
fn output_or_implicit_subscripts(&self) -> Vec<Subscript> {
174+
match &self.output {
175+
Some(o) => o.subscripts(),
176+
None => {
177+
// In implicit mode output indices are set to the alphabetically sorted sequence of indices
178+
// appearing exactly once in the equation.
179+
let counts = count_indices(&self.inputs);
180+
let mut keys: Vec<Subscript> = counts
181+
.into_iter()
182+
.filter_map(|(k, v)| if v == 1 { Some(k) } else { None })
183+
.collect();
184+
keys.sort();
185+
keys
186+
}
187+
}
188+
}
189+
190+
fn contraction_indices(&self) -> Vec<Subscript> {
191+
let count = count_indices(&self.inputs);
192+
let mut subscripts: BTreeSet<Subscript> = count
193+
.into_iter()
194+
.filter_map(|(key, value)| if value > 1 { Some(key) } else { None })
195+
.collect();
196+
for c in &self.output_or_implicit_subscripts() {
197+
subscripts.remove(c);
198+
}
199+
subscripts.into_iter().collect()
200+
}
201+
}
202+
203+
impl Display for Einsum {
204+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205+
write!(
206+
f,
207+
"{}",
208+
self.inputs
209+
.iter()
210+
.map(|x| x.to_string())
211+
.collect::<Vec<String>>()
212+
.join(",")
213+
)?;
214+
215+
if let Some(output) = &self.output {
216+
write!(f, " -> {}", output)?;
217+
}
218+
Ok(())
219+
}
220+
}
221+
222+
#[cfg(test)]
223+
mod tests {
224+
use super::{count_indices, Einsum, Subscript};
225+
226+
pub fn compare_after_reserialize(formula: &str, expected: &str) {
227+
assert_eq!(Einsum::from(formula).unwrap().to_string(), expected);
228+
}
229+
230+
pub fn expect_fail(formula: &str) {
231+
assert!(Einsum::from(formula).is_err())
232+
}
233+
234+
#[test]
235+
pub fn test_parse_einsum() {
236+
compare_after_reserialize("ij,jk->ik", "ij,jk -> ik");
237+
compare_after_reserialize(" i j, j k -> i k", "ij,jk -> ik");
238+
compare_after_reserialize(" i j-> i k", "ij -> ik");
239+
240+
compare_after_reserialize("a ...d,x... z->a ...z", "a...d,x...z -> a...z");
241+
compare_after_reserialize(" ...d,x... z->a ...", "...d,x...z -> a...");
242+
compare_after_reserialize("a...", "a...");
243+
compare_after_reserialize("a ...d,x... z->...", "a...d,x...z -> ...");
244+
245+
expect_fail("ij- >ik");
246+
expect_fail("->ik");
247+
expect_fail("a ...d,x... z->a . ..z");
248+
expect_fail("a...b...c");
249+
expect_fail("a....b...c");
250+
expect_fail("a..b...c");
251+
}
252+
253+
#[test]
254+
pub fn test_indices() {
255+
let es = Einsum::from("ij,jk->ik").unwrap();
256+
let out = count_indices(&es.inputs);
257+
assert_eq!(out.len(), 3);
258+
assert_eq!(out[&Subscript::from('i')], 1);
259+
assert_eq!(out[&Subscript::from('j')], 2);
260+
assert_eq!(out[&Subscript::from('k')], 1);
261+
262+
let es = Einsum::from("i...k,k...m->i...m").unwrap();
263+
let out = count_indices(&es.inputs);
264+
println!("{:?}", out);
265+
assert_eq!(out.len(), 5);
266+
assert_eq!(out[&Subscript::from('i')], 1);
267+
assert_eq!(out[&Subscript::from('j')], 1);
268+
assert_eq!(out[&Subscript::from('k')], 2);
269+
assert_eq!(out[&Subscript::from('l')], 1);
270+
assert_eq!(out[&Subscript::from('m')], 1);
271+
}
272+
273+
#[test]
274+
pub fn test_analysis() {
275+
let es = Einsum::from("ij,jk->ik").unwrap();
276+
assert_eq!(es.contraction_indices(), vec![Subscript::from('j')]);
277+
278+
let es = Einsum::from("ij,jk").unwrap();
279+
assert_eq!(
280+
es.output_or_implicit_subscripts(),
281+
vec![Subscript::from('i'), Subscript::from('k')]
282+
);
283+
assert_eq!(es.contraction_indices(), vec![Subscript::from('j')]);
284+
285+
let transpose = Einsum::from("ba").unwrap();
286+
assert_eq!(
287+
transpose.output_or_implicit_subscripts(),
288+
vec![Subscript::from('a'), Subscript::from('b')]
289+
);
290+
}
291+
}

wonnx/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod compiler;
2+
mod einsum;
23
mod gpu;
34
mod ir;
45
pub mod onnx;

0 commit comments

Comments
 (0)