"""
Relational Memory Core - Quick Demonstration
Paper 18: Relational RNN - Implementation Task P2-T1
This script provides a concise demonstration of the Relational Memory Core
functionality, showing how it maintains and updates multiple memory slots
that interact via self-attention.
"""
import numpy as np
from relational_memory import RelationalMemory, init_memory
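
# Background (a sketch of the mechanism, not this module's exact code): in
# the Relational RNN paper (Santoro et al., 2018), the memory is a matrix M
# of num_slots row vectors. Each step, M is updated with multi-head
# dot-product attention in which M supplies the queries and [M; x] (memory
# with the input row-appended) supplies the keys and values, followed by a
# residual connection, a row-wise MLP, and LSTM-style gating. The
# RelationalMemory class imported above is assumed to follow this scheme.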


def main():
    """Demonstrate relational memory core capabilities."""
    print("\n" + "=" * 80)
    print("RELATIONAL MEMORY CORE - QUICK DEMONSTRATION")
    print("=" * 80 + "\n")

    np.random.seed(42)

    # Configuration (as specified in task)
    batch_size = 2
    num_slots = 4
    slot_size = 64
    num_heads = 2

    print("Configuration:")
    print(f" Batch size: {batch_size}")
    print(f" Number of memory slots: {num_slots}")
    print(f" Slot dimension: {slot_size}")
    print(f" Number of attention heads: {num_heads}")

    # Create relational memory module
    print("\n1. Creating Relational Memory Core...")
    rm = RelationalMemory(
        num_slots=num_slots,
        slot_size=slot_size,
        num_heads=num_heads,
        use_gate=True,
        use_input_attention=True
    )
    print(" Created successfully!")

    # Initialize memory
    print("\n2. Initializing memory state...")
    memory = rm.reset_memory(batch_size)
    print(f" Memory shape: {memory.shape}")
    print(f" (batch_size={batch_size}, num_slots={num_slots}, slot_size={slot_size})")

    # Forward pass without input
    print("\n3. Running self-attention across memory slots (no input)...")
    updated_memory, attention_weights = rm.forward(memory)
    print(f" Updated memory shape: {updated_memory.shape}")
    print(f" Attention weights shape: {attention_weights.shape}")
    print(f" Attention weights sum to 1.0: {np.allclose(np.sum(attention_weights, axis=-1), 1.0)}")

    # Show attention pattern for first example, first head
    print("\n4. Attention pattern (batch 0, head 0):")
    print(" How much each slot attends to others:")
    attn = attention_weights[0, 0]
    for i in range(num_slots):
        print(f" Slot {i}: [{', '.join([f'{attn[i,j]:.3f}' for j in range(num_slots)])}]")

    # Forward pass with input
    print("\n5. Incorporating external input into memory...")
    input_vec = np.random.randn(batch_size, 32)
    updated_memory_with_input, _ = rm.forward(memory, input_vec)
    print(f" Input shape: {input_vec.shape}")
    print(f" Updated memory shape: {updated_memory_with_input.shape}")

    # Memory changes when input is provided
    difference = np.linalg.norm(updated_memory - updated_memory_with_input)
    print(f" Difference from no-input case: {difference:.4f}")
    print(f" Input successfully incorporated: {difference > 0.1}")

    # Simulate a sequence
    print("\n6. Processing a sequence of 5 inputs...")
    memory_seq = rm.reset_memory(batch_size)
    for t in range(5):
        input_t = np.random.randn(batch_size, 32)
        memory_seq, attn_t = rm.forward(memory_seq, input_t)
        mean_attn = np.mean(attn_t[0, 0])
        print(f" Step {t+1}: Memory updated, mean attention = {mean_attn:.4f}")
    print(f"\n Final memory shape: {memory_seq.shape}")

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print("\nKey Capabilities Demonstrated:")
    print(" 1. Multiple memory slots maintain distinct representations")
    print(" 2. Self-attention allows slots to interact and share information")
    print(" 3. External inputs can be incorporated into memory")
    print(" 4. Memory evolves over time through sequential processing")
    print(" 5. Multi-head attention enables different types of relationships")
    print("\nRelational Reasoning Advantages:")
    print(" - Can represent multiple entities/concepts simultaneously")
    print(" - Slots can reason about relationships via attention")
    print(" - More structured than single-vector hidden states")
    print(" - Better suited for tasks requiring relational reasoning")
    print("\n" + "=" * 80 + "\n")


if __name__ == "__main__":
    main()