mas_handlers/activity_tracker/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7mod bound;
8mod worker;
9
10use std::net::IpAddr;
11
12use chrono::{DateTime, Utc};
13use mas_data_model::{
14    BrowserSession, Clock, CompatSession, Session, personal::session::PersonalSession,
15};
16use mas_storage::BoxRepositoryFactory;
17use tokio_util::{sync::CancellationToken, task::TaskTracker};
18use ulid::Ulid;
19
20pub use self::bound::Bound;
21use self::worker::Worker;
22
/// Capacity of the channel between the [`ActivityTracker`] handles and the
/// background worker. Once this many messages are queued, recording activity
/// waits for the worker to catch up.
const MESSAGE_QUEUE_SIZE: usize = 1000;
24
/// Which family of session an activity record belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
enum SessionKind {
    /// An OAuth 2.0 session
    OAuth2,
    /// A compat session (see `CompatSession`)
    Compat,
    /// Session associated with personal access tokens
    Personal,
    /// A browser session (see `BrowserSession`)
    Browser,
}

impl SessionKind {
    /// Lowercase identifier for this kind of session.
    const fn as_str(self) -> &'static str {
        match self {
            Self::OAuth2 => "oauth2",
            Self::Compat => "compat",
            Self::Personal => "personal",
            Self::Browser => "browser",
        }
    }
}
44
45enum Message {
46    Record {
47        kind: SessionKind,
48        id: Ulid,
49        date_time: DateTime<Utc>,
50        ip: Option<IpAddr>,
51    },
52    Flush(tokio::sync::oneshot::Sender<()>),
53}
54
55#[derive(Clone)]
56pub struct ActivityTracker {
57    channel: tokio::sync::mpsc::Sender<Message>,
58}
59
60impl ActivityTracker {
61    /// Create a new activity tracker
62    ///
63    /// It will spawn the background worker and a loop to flush the tracker on
64    /// the task tracker, and both will shut themselves down, flushing one last
65    /// time, when the cancellation token is cancelled.
66    #[must_use]
67    pub fn new(
68        repository_factory: BoxRepositoryFactory,
69        flush_interval: std::time::Duration,
70        task_tracker: &TaskTracker,
71        cancellation_token: CancellationToken,
72    ) -> Self {
73        let worker = Worker::new(repository_factory);
74        let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
75        let tracker = ActivityTracker { channel: sender };
76
77        // Spawn the flush loop and the worker
78        task_tracker.spawn(
79            tracker
80                .clone()
81                .flush_loop(flush_interval, cancellation_token.clone()),
82        );
83        task_tracker.spawn(worker.run(receiver, cancellation_token));
84
85        tracker
86    }
87
88    /// Bind the activity tracker to an IP address.
89    #[must_use]
90    pub fn bind(self, ip: Option<IpAddr>) -> Bound {
91        Bound::new(self, ip)
92    }
93
94    /// Record activity in an OAuth 2.0 session.
95    pub async fn record_oauth2_session(
96        &self,
97        clock: &dyn Clock,
98        session: &Session,
99        ip: Option<IpAddr>,
100    ) {
101        let res = self
102            .channel
103            .send(Message::Record {
104                kind: SessionKind::OAuth2,
105                id: session.id,
106                date_time: clock.now(),
107                ip,
108            })
109            .await;
110
111        if let Err(e) = res {
112            tracing::error!("Failed to record OAuth2 session: {}", e);
113        }
114    }
115
116    /// Record activity in a personal access token session.
117    pub async fn record_personal_access_token_session(
118        &self,
119        clock: &dyn Clock,
120        session: &PersonalSession,
121        ip: Option<IpAddr>,
122    ) {
123        let res = self
124            .channel
125            .send(Message::Record {
126                kind: SessionKind::Personal,
127                id: session.id,
128                date_time: clock.now(),
129                ip,
130            })
131            .await;
132
133        if let Err(e) = res {
134            tracing::error!("Failed to record Personal session: {}", e);
135        }
136    }
137
138    /// Record activity in a compat session.
139    pub async fn record_compat_session(
140        &self,
141        clock: &dyn Clock,
142        compat_session: &CompatSession,
143        ip: Option<IpAddr>,
144    ) {
145        let res = self
146            .channel
147            .send(Message::Record {
148                kind: SessionKind::Compat,
149                id: compat_session.id,
150                date_time: clock.now(),
151                ip,
152            })
153            .await;
154
155        if let Err(e) = res {
156            tracing::error!("Failed to record compat session: {}", e);
157        }
158    }
159
160    /// Record activity in a browser session.
161    pub async fn record_browser_session(
162        &self,
163        clock: &dyn Clock,
164        browser_session: &BrowserSession,
165        ip: Option<IpAddr>,
166    ) {
167        let res = self
168            .channel
169            .send(Message::Record {
170                kind: SessionKind::Browser,
171                id: browser_session.id,
172                date_time: clock.now(),
173                ip,
174            })
175            .await;
176
177        if let Err(e) = res {
178            tracing::error!("Failed to record browser session: {}", e);
179        }
180    }
181
182    /// Manually flush the activity tracker.
183    pub async fn flush(&self) {
184        let (tx, rx) = tokio::sync::oneshot::channel();
185        let res = self.channel.send(Message::Flush(tx)).await;
186
187        match res {
188            Ok(()) => {
189                if let Err(e) = rx.await {
190                    tracing::error!(
191                        error = &e as &dyn std::error::Error,
192                        "Failed to flush activity tracker"
193                    );
194                }
195            }
196            Err(e) => {
197                tracing::error!(
198                    error = &e as &dyn std::error::Error,
199                    "Failed to flush activity tracker"
200                );
201            }
202        }
203    }
204
205    /// Regularly flush the activity tracker.
206    async fn flush_loop(
207        self,
208        interval: std::time::Duration,
209        cancellation_token: CancellationToken,
210    ) {
211        // This guard on the shutdown token is to ensure that if this task crashes for
212        // any reason, the server will shut down
213        let _guard = cancellation_token.clone().drop_guard();
214        let mut interval = tokio::time::interval(interval);
215        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
216
217        loop {
218            tokio::select! {
219                biased;
220
221                () = cancellation_token.cancelled() => {
222                    // The cancellation token was cancelled, so we should exit
223                    return;
224                }
225
226                // First check if the channel is closed, then check if the timer expired
227                () = self.channel.closed() => {
228                    // The channel was closed, so we should exit
229                    return;
230                }
231
232
233                _ = interval.tick() => {
234                    self.flush().await;
235                }
236            }
237        }
238    }
239}