Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Commit 5529bee

Browse files
author
Nate Rush
authored
Merge pull request #169 from ethereum/feat/sharding
Add blockchain sharding
2 parents 7edc277 + 307e7d4 commit 5529bee

File tree

8 files changed

+457
-1
lines changed

8 files changed

+457
-1
lines changed

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ run-nofinal:
3030

3131
run-binary:
3232
venv/bin/python casper.py rand --protocol binary --report-interval 3
33+
34+
run-sharding:
35+
venv/bin/python casper.py rand --protocol sharding --validators 14 --report-interval 3

casper/protocols/sharding/__init__.py

Whitespace-only changes.

casper/protocols/sharding/block.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""The block module implements the message data structure for a sharded blockchain"""
2+
from casper.message import Message
3+
NUM_MERGE_SHARDS = 2
4+
5+
6+
class Block(Message):
7+
"""Message data structure for a sharded blockchain"""
8+
9+
@classmethod
10+
def is_valid_estimate(cls, estimate):
11+
for key in ['prev_blocks', 'shard_ids']:
12+
if key not in estimate:
13+
return False
14+
if not isinstance(estimate[key], set):
15+
return False
16+
return True
17+
18+
def on_shard(self, shard_id):
19+
return shard_id in self.estimate['shard_ids']
20+
21+
def prev_block(self, shard_id):
22+
"""Returns the previous block on the shard: shard_id
23+
Throws a KeyError if there is no previous block"""
24+
if shard_id not in self.estimate['shard_ids']:
25+
raise KeyError("No previous block on that shard")
26+
27+
for block in self.estimate['prev_blocks']:
28+
# if this block is the genesis, previous is None
29+
if block is None:
30+
return None
31+
32+
# otherwise, return the previous block on that shard
33+
if block.on_shard(shard_id):
34+
return block
35+
36+
raise KeyError("Block on {}, but has no previous block on that shard!".format(shard_id))
37+
38+
@property
39+
def is_merge_block(self):
40+
return len(self.estimate['shard_ids']) == NUM_MERGE_SHARDS
41+
42+
@property
43+
def is_genesis_block(self):
44+
return None in self.estimate['prev_blocks']
45+
46+
def conflicts_with(self, message):
47+
"""Returns true if self is not in the prev blocks of other_message"""
48+
assert isinstance(message, Block), "...expected a block"
49+
50+
return not self.is_in_blockchain(message, '')
51+
52+
def is_in_blockchain(self, block, shard_id):
53+
"""Could be a zero generation ancestor!"""
54+
if not block:
55+
return False
56+
57+
if not block.on_shard(shard_id):
58+
return False
59+
60+
if self == block:
61+
return True
62+
63+
return self.is_in_blockchain(block.prev_block(shard_id), shard_id)
+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""The forkchoice module implements the estimator function a sharded blockchain"""
2+
3+
4+
def get_max_weight_indexes(scores):
5+
"""Returns the keys that map to the max value in a dict.
6+
The max value must be greater than zero."""
7+
8+
max_score = max(scores.values())
9+
10+
assert max_score > 0, "max_score should be greater than zero"
11+
12+
max_weight_estimates = {e for e in scores if scores[e] == max_score}
13+
14+
return max_weight_estimates
15+
16+
17+
def get_scores(starting_block, latest_messages, shard_id):
18+
"""Returns a dict of block => weight"""
19+
scores = dict()
20+
21+
for validator, current_block in latest_messages.items():
22+
if not current_block.on_shard(shard_id):
23+
continue
24+
25+
while current_block and current_block != starting_block:
26+
scores[current_block] = scores.get(current_block, 0) + validator.weight
27+
current_block = current_block.prev_block(shard_id)
28+
29+
return scores
30+
31+
32+
def get_shard_fork_choice(starting_block, children, latest_messages, shard_id):
33+
"""Get the forkchoice for a specific shard"""
34+
35+
scores = get_scores(starting_block, latest_messages, shard_id)
36+
37+
best_block = starting_block
38+
while best_block in children:
39+
curr_scores = dict()
40+
max_score = 0
41+
for child in children[best_block]:
42+
if not child.on_shard(shard_id):
43+
continue # we only select children on the same shard
44+
# can't pick a child that a merge block with a higher shard
45+
if child.is_merge_block:
46+
not_in_forkchoice = False
47+
for shard in child.estimate['shard_ids']:
48+
if len(shard) < len(shard_id):
49+
not_in_forkchoice = True
50+
break
51+
if not_in_forkchoice:
52+
continue
53+
curr_scores[child] = scores.get(child, 0)
54+
max_score = max(curr_scores[child], max_score)
55+
56+
# If no child on shard, or 0 weight block, stop
57+
if max_score == 0:
58+
break
59+
60+
max_weight_children = get_max_weight_indexes(curr_scores)
61+
62+
assert len(max_weight_children) == 1, "... there should be no ties!"
63+
64+
best_block = max_weight_children.pop()
65+
66+
return best_block
67+
68+
69+
def get_all_shards_fork_choice(starting_blocks, children, latest_messages_on_shard):
70+
"""Returns a dict of shard_id -> forkchoice.
71+
Starts from starting block for shard, and stops when it reaches tip"""
72+
73+
# for any shard we have latest messages on, we should have a starting block
74+
for key in starting_blocks.keys():
75+
assert key in latest_messages_on_shard
76+
for key in latest_messages_on_shard.keys():
77+
assert key in latest_messages_on_shard
78+
79+
shards_forkchoice = {
80+
shard_id: get_shard_fork_choice(
81+
starting_blocks[shard_id],
82+
children,
83+
latest_messages_on_shard[shard_id],
84+
shard_id
85+
) for shard_id in starting_blocks
86+
}
87+
88+
return shards_forkchoice
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""The blockchain plot tool implements functions for plotting sharded blockchain data structures"""
2+
3+
from casper.plot_tool import PlotTool
4+
from casper.safety_oracles.clique_oracle import CliqueOracle
5+
import casper.utils as utils
6+
7+
8+
class ShardingPlotTool(PlotTool):
9+
"""The module contains functions for plotting a blockchain data structure"""
10+
11+
def __init__(self, display, save, view, validator_set):
12+
super().__init__(display, save, 's')
13+
self.view = view
14+
self.validator_set = validator_set
15+
self.starting_blocks = self.view.starting_blocks
16+
self.message_fault_tolerance = dict()
17+
18+
self.blockchain = []
19+
self.communications = []
20+
21+
self.block_fault_tolerance = {}
22+
self.message_labels = {}
23+
self.justifications = {
24+
validator: []
25+
for validator in validator_set
26+
}
27+
28+
def update(self, new_messages=None):
29+
"""Updates displayable items with new messages and paths"""
30+
return
31+
32+
if new_messages is None:
33+
new_messages = []
34+
35+
self._update_new_justifications(new_messages)
36+
self._update_blockchain(new_messages)
37+
self._update_block_fault_tolerance()
38+
self._update_message_labels(new_messages)
39+
40+
def plot(self):
41+
"""Builds relevant edges to display and creates next viewgraph using them"""
42+
return
43+
best_chain_edge = self.get_best_chain()
44+
45+
validator_chain_edges = self.get_validator_chains()
46+
47+
edgelist = []
48+
edgelist.append(utils.edge(self.blockchain, 2, 'grey', 'solid'))
49+
edgelist.append(utils.edge(self.communications, 1, 'black', 'dotted'))
50+
edgelist.append(best_chain_edge)
51+
edgelist.extend(validator_chain_edges)
52+
53+
self.next_viewgraph(
54+
self.view,
55+
self.validator_set,
56+
edges=edgelist,
57+
message_colors=self.block_fault_tolerance,
58+
message_labels=self.message_labels
59+
)
60+
61+
def get_best_chain(self):
62+
"""Returns an edge made of the global forkchoice to genesis"""
63+
best_message = self.view.estimate()
64+
best_chain = utils.build_chain(best_message, None)[:-1]
65+
return utils.edge(best_chain, 5, 'red', 'solid')
66+
67+
def get_validator_chains(self):
68+
"""Returns a list of edges main from validators current forkchoice to genesis"""
69+
vals_chain_edges = []
70+
for validator in self.validator_set:
71+
chain = utils.build_chain(validator.my_latest_message(), None)[:-1]
72+
vals_chain_edges.append(utils.edge(chain, 2, 'blue', 'solid'))
73+
74+
return vals_chain_edges
75+
76+
def _update_new_justifications(self, new_messages):
77+
for message in new_messages:
78+
sender = message.sender
79+
for validator in message.justification:
80+
last_message = self.view.justified_messages[message.justification[validator]]
81+
# only show if new justification
82+
if last_message not in self.justifications[sender]:
83+
self.communications.append([last_message, message])
84+
self.justifications[sender].append(last_message)
85+
86+
def _update_blockchain(self, new_messages):
87+
for message in new_messages:
88+
if message.estimate is not None:
89+
self.blockchain.append([message, message.estimate])
90+
91+
def _update_message_labels(self, new_messages):
92+
for message in new_messages:
93+
self.message_labels[message] = message.sequence_number
94+
95+
def _update_block_fault_tolerance(self):
96+
tip = self.view.estimate()
97+
98+
while tip and self.block_fault_tolerance.get(tip, 0) != len(self.validator_set) - 1:
99+
oracle = CliqueOracle(tip, self.view, self.validator_set)
100+
fault_tolerance, num_node_ft = oracle.check_estimate_safety()
101+
102+
if fault_tolerance > 0:
103+
self.block_fault_tolerance[tip] = num_node_ft
104+
105+
tip = tip.estimate
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from casper.protocols.sharding.sharding_view import ShardingView
2+
from casper.protocols.sharding.block import Block
3+
from casper.protocols.sharding.sharding_plot_tool import ShardingPlotTool
4+
from casper.protocol import Protocol
5+
6+
7+
class ShardingProtocol(Protocol):
8+
View = ShardingView
9+
Message = Block
10+
PlotTool = ShardingPlotTool
11+
12+
shard_genesis_blocks = dict()
13+
curr_shard_idx = 0
14+
curr_shard_ids = ['']
15+
16+
"""Shard ID's look like this:
17+
''
18+
/ \
19+
'0' '1'
20+
/ \ / \
21+
'00''01''10''11'
22+
23+
24+
Blocks can be merge mined between shards if
25+
there is an edge between shards
26+
That is, for ids shard_1 and shard_2, there can be a merge block if
27+
abs(len(shard_1) - len(shard_2)) = 1 AND
28+
for i in range(min(len(shard_1), len(shard_2))):
29+
shard_1[i] = shard_2[i]
30+
"""
31+
32+
@classmethod
33+
def initial_message(cls, validator):
34+
"""Returns a starting block for a shard"""
35+
shard_id = cls.get_next_shard_id()
36+
37+
estimate = {'prev_blocks': set([None]), 'shard_ids': set([shard_id])}
38+
cls.shard_genesis_blocks[shard_id] = Block(estimate, dict(), validator, -1, 0)
39+
40+
return cls.shard_genesis_blocks['']
41+
42+
@classmethod
43+
def get_next_shard_id(cls):
44+
next_id = cls.curr_shard_ids[cls.curr_shard_idx]
45+
cls.curr_shard_idx += 1
46+
47+
if cls.curr_shard_idx == len(cls.curr_shard_ids):
48+
next_ids = []
49+
for shard_id in cls.curr_shard_ids:
50+
next_ids.append(shard_id + '0')
51+
next_ids.append(shard_id + '1')
52+
53+
cls.curr_shard_idx = 0
54+
cls.curr_shard_ids = next_ids
55+
56+
return next_id

0 commit comments

Comments
 (0)