Codebase list matrix-synapse / b5f99d1
Merge branch 'debian/unstable' into debian/buster-backports (Andrej Shadura)
337 changed file(s) with 8565 addition(s) and 4982 deletion(s).
0 Synapse 1.20.1 (2020-09-24)
1 ===========================
2
3 Bugfixes
4 --------
5
6 - Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386))
7 - Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394))
8
9
10 Synapse 1.20.0 (2020-09-22)
11 ===========================
12
13 No significant changes since v1.20.0rc5.
14
15 Removal warning
16 ---------------
17
18 Historically, the [Synapse Admin
19 API](https://github.com/matrix-org/synapse/tree/master/docs) has been
20 accessible under the `/_matrix/client/api/v1/admin`,
21 `/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and
22 `/_synapse/admin` prefixes. In a future release, we will be dropping support
23 for accessing Synapse's Admin API using the `/_matrix/client/*` prefixes. This
24 makes it easier for homeserver admins to lock down external access to the Admin
25 API endpoints.
26
27 Synapse 1.20.0rc5 (2020-09-18)
28 ==============================
29
30 In addition to the below, Synapse 1.20.0rc5 also includes the bug fix shipped in 1.19.3.
31
32 Features
33 --------
34
35 - Add flags to the `/versions` endpoint for whether new rooms default to using E2EE. ([\#8343](https://github.com/matrix-org/synapse/issues/8343))
36
37
38 Bugfixes
39 --------
40
41 - Fix rate limiting of federation `/send` requests. ([\#8342](https://github.com/matrix-org/synapse/issues/8342))
42 - Fix a longstanding bug where back pagination over federation could get stuck if it failed to handle a received event. ([\#8349](https://github.com/matrix-org/synapse/issues/8349))
43
44
45 Internal Changes
46 ----------------
47
48 - Blacklist [MSC2753](https://github.com/matrix-org/matrix-doc/pull/2753) SyTests until it is implemented. ([\#8285](https://github.com/matrix-org/synapse/issues/8285))
49
50
51 Synapse 1.19.3 (2020-09-18)
52 ===========================
53
54 Bugfixes
55 --------
56
57 - Partially mitigate a bug where newly joined servers couldn't retrieve past events in a room containing a malformed event. ([\#8350](https://github.com/matrix-org/synapse/issues/8350))
58
59
60 Synapse 1.20.0rc4 (2020-09-16)
61 ==============================
62
63 Synapse 1.20.0rc4 is identical to 1.20.0rc3, with the addition of the security fix that was included in 1.19.2.
64
65
66 Synapse 1.19.2 (2020-09-16)
67 ===========================
68
69 Due to the issue below, server admins are encouraged to upgrade as soon as possible.
70
71 Bugfixes
72 --------
73
74 - Fix joining rooms over federation that include malformed events. ([\#8324](https://github.com/matrix-org/synapse/issues/8324))
75
76
77 Synapse 1.20.0rc3 (2020-09-11)
78 ==============================
79
80 Bugfixes
81 --------
82
83 - Fix a bug introduced in v1.20.0rc1 where the wrong exception was raised when invalid JSON data was encountered. ([\#8291](https://github.com/matrix-org/synapse/issues/8291))
84
85
86 Synapse 1.20.0rc2 (2020-09-09)
87 ==============================
88
89 Bugfixes
90 --------
91
92 - Fix a bug introduced in v1.20.0rc1 causing some features related to notifications to misbehave following the implementation of unread counts. ([\#8280](https://github.com/matrix-org/synapse/issues/8280))
93
94
95 Synapse 1.20.0rc1 (2020-09-08)
96 ==============================
97
98 Removal warning
99 ---------------
100
101 Some older clients used a [disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken) (`:`) in the `client_secret` parameter of various endpoints. The incorrect behaviour was allowed for backwards compatibility, but is now being removed from Synapse as most users have updated their client. Further context can be found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
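For illustration, the grammar the spec imposes on `client_secret` (non-empty, at most 255 characters, drawn from `[0-9a-zA-Z.=_-]`) can be checked with a short regex. This is a sketch of the rule, not Synapse's internal code; the helper name is ours:

```
import re

# Character set from the spec linked above; note that ':' is absent.
CLIENT_SECRET_RE = re.compile(r"^[0-9a-zA-Z.=_-]+$")

def is_valid_client_secret(client_secret: str) -> bool:
    """Return True if `client_secret` satisfies the spec's grammar."""
    return (
        0 < len(client_secret) <= 255
        and CLIENT_SECRET_RE.match(client_secret) is not None
    )

assert is_valid_client_secret("monkeys_are_GREAT")
assert not is_valid_client_secret("old:style:secret")  # ':' is now rejected
```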
102
103 Features
104 --------
105
106 - Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666). ([\#7785](https://github.com/matrix-org/synapse/issues/7785))
107 - Iteratively encode JSON to avoid blocking the reactor; a short sketch of the idea follows this list. ([\#8013](https://github.com/matrix-org/synapse/issues/8013), [\#8116](https://github.com/matrix-org/synapse/issues/8116))
108 - Add support for shadow-banning users (ignoring any message send requests). ([\#8034](https://github.com/matrix-org/synapse/issues/8034), [\#8092](https://github.com/matrix-org/synapse/issues/8092), [\#8095](https://github.com/matrix-org/synapse/issues/8095), [\#8142](https://github.com/matrix-org/synapse/issues/8142), [\#8152](https://github.com/matrix-org/synapse/issues/8152), [\#8157](https://github.com/matrix-org/synapse/issues/8157), [\#8158](https://github.com/matrix-org/synapse/issues/8158), [\#8176](https://github.com/matrix-org/synapse/issues/8176))
109 - Use the default template file when its equivalent is not found in a custom template directory. ([\#8037](https://github.com/matrix-org/synapse/issues/8037), [\#8107](https://github.com/matrix-org/synapse/issues/8107), [\#8252](https://github.com/matrix-org/synapse/issues/8252))
110 - Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). ([\#8059](https://github.com/matrix-org/synapse/issues/8059), [\#8254](https://github.com/matrix-org/synapse/issues/8254), [\#8270](https://github.com/matrix-org/synapse/issues/8270), [\#8274](https://github.com/matrix-org/synapse/issues/8274))
111 - Optimise `/federation/v1/user/devices/` API by only returning devices with encryption keys. ([\#8198](https://github.com/matrix-org/synapse/issues/8198))
112
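As referenced in the features list above, here is a rough sketch of the iterative-encoding idea (not Synapse's actual implementation): `json.JSONEncoder.iterencode` yields the output in many small fragments, so an event loop can interleave other work between fragments instead of serialising one large string in a single blocking call.

```
import json

# Compact separators also keep extra whitespace out of the encoded output.
_encoder = json.JSONEncoder(separators=(",", ":"))

def encode_in_chunks(value, min_chunk_size=8192):
    """Yield the JSON encoding of `value` in pieces of roughly
    `min_chunk_size` characters rather than one large string."""
    buffer, buffered = [], 0
    for fragment in _encoder.iterencode(value):
        buffer.append(fragment)
        buffered += len(fragment)
        if buffered >= min_chunk_size:
            yield "".join(buffer)
            buffer, buffered = [], 0
    if buffer:
        yield "".join(buffer)
```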
113
114 Bugfixes
115 --------
116
117 - Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable. ([\#7864](https://github.com/matrix-org/synapse/issues/7864))
118 - Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed. ([\#8081](https://github.com/matrix-org/synapse/issues/8081))
119 - Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints. ([\#8101](https://github.com/matrix-org/synapse/issues/8101))
120 - Fix a bug introduced in v1.7.2 impacting message retention policies, which allowed federated homeservers to dictate a retention period lower than the minimum duration allowed in the configuration file. ([\#8104](https://github.com/matrix-org/synapse/issues/8104))
121 - Fix a long-standing bug where invalid JSON would be accepted by Synapse. ([\#8106](https://github.com/matrix-org/synapse/issues/8106))
122 - Fix a bug introduced in Synapse v1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite. ([\#8110](https://github.com/matrix-org/synapse/issues/8110))
123 - Return a proper error code when the rooms of an invalid group are requested. ([\#8129](https://github.com/matrix-org/synapse/issues/8129))
124 - Fix a bug which could cause a leaked postgres connection if synapse was set to daemonize. ([\#8131](https://github.com/matrix-org/synapse/issues/8131))
125 - Clarify the error code if a user tries to register with a numeric ID. This bug was introduced in v1.15.0. ([\#8135](https://github.com/matrix-org/synapse/issues/8135))
126 - Fix a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
127 - Fix logging in via OpenID Connect with a provider that uses integer user IDs. ([\#8190](https://github.com/matrix-org/synapse/issues/8190))
128 - Fix a longstanding bug where user directory updates could break when unexpected profile data was included in events. ([\#8223](https://github.com/matrix-org/synapse/issues/8223))
129 - Fix a longstanding bug where stats updates could break when unexpected profile data was included in events. ([\#8226](https://github.com/matrix-org/synapse/issues/8226))
130 - Fix slow start times for large servers by removing a table scan of the `users` table from startup code. ([\#8271](https://github.com/matrix-org/synapse/issues/8271))
131
132
133 Updates to the Docker image
134 ---------------------------
135
136 - Fix builds of the Docker image on non-x86 platforms. ([\#8144](https://github.com/matrix-org/synapse/issues/8144))
137 - Add curl to the image for healthcheck support, and update the README to document the change. Contributed by @maquis196. ([\#8147](https://github.com/matrix-org/synapse/issues/8147))
138
139
140 Improved Documentation
141 ----------------------
142
143 - Link to matrix-synapse-rest-password-provider in the password provider documentation. ([\#8111](https://github.com/matrix-org/synapse/issues/8111))
144 - Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole. ([\#8120](https://github.com/matrix-org/synapse/issues/8120))
145 - Explain better what GDPR-erased means when deactivating a user. ([\#8189](https://github.com/matrix-org/synapse/issues/8189))
146
147
148 Internal Changes
149 ----------------
150
151 - Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7377](https://github.com/matrix-org/synapse/issues/7377), [\#8163](https://github.com/matrix-org/synapse/issues/8163))
152 - Reduce run times of some unit tests by advancing the reactor fewer times. ([\#7757](https://github.com/matrix-org/synapse/issues/7757))
153 - Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on. ([\#7991](https://github.com/matrix-org/synapse/issues/7991))
154 - Convert various parts of the codebase to async/await. ([\#8071](https://github.com/matrix-org/synapse/issues/8071), [\#8072](https://github.com/matrix-org/synapse/issues/8072), [\#8074](https://github.com/matrix-org/synapse/issues/8074), [\#8075](https://github.com/matrix-org/synapse/issues/8075), [\#8076](https://github.com/matrix-org/synapse/issues/8076), [\#8087](https://github.com/matrix-org/synapse/issues/8087), [\#8100](https://github.com/matrix-org/synapse/issues/8100), [\#8119](https://github.com/matrix-org/synapse/issues/8119), [\#8121](https://github.com/matrix-org/synapse/issues/8121), [\#8133](https://github.com/matrix-org/synapse/issues/8133), [\#8156](https://github.com/matrix-org/synapse/issues/8156), [\#8162](https://github.com/matrix-org/synapse/issues/8162), [\#8166](https://github.com/matrix-org/synapse/issues/8166), [\#8168](https://github.com/matrix-org/synapse/issues/8168), [\#8173](https://github.com/matrix-org/synapse/issues/8173), [\#8191](https://github.com/matrix-org/synapse/issues/8191), [\#8192](https://github.com/matrix-org/synapse/issues/8192), [\#8193](https://github.com/matrix-org/synapse/issues/8193), [\#8194](https://github.com/matrix-org/synapse/issues/8194), [\#8195](https://github.com/matrix-org/synapse/issues/8195), [\#8197](https://github.com/matrix-org/synapse/issues/8197), [\#8199](https://github.com/matrix-org/synapse/issues/8199), [\#8200](https://github.com/matrix-org/synapse/issues/8200), [\#8201](https://github.com/matrix-org/synapse/issues/8201), [\#8202](https://github.com/matrix-org/synapse/issues/8202), [\#8207](https://github.com/matrix-org/synapse/issues/8207), [\#8213](https://github.com/matrix-org/synapse/issues/8213), [\#8214](https://github.com/matrix-org/synapse/issues/8214))
155 - Remove some unused database functions. ([\#8085](https://github.com/matrix-org/synapse/issues/8085))
156 - Add type hints to various parts of the codebase. ([\#8090](https://github.com/matrix-org/synapse/issues/8090), [\#8127](https://github.com/matrix-org/synapse/issues/8127), [\#8187](https://github.com/matrix-org/synapse/issues/8187), [\#8241](https://github.com/matrix-org/synapse/issues/8241), [\#8140](https://github.com/matrix-org/synapse/issues/8140), [\#8183](https://github.com/matrix-org/synapse/issues/8183), [\#8232](https://github.com/matrix-org/synapse/issues/8232), [\#8235](https://github.com/matrix-org/synapse/issues/8235), [\#8237](https://github.com/matrix-org/synapse/issues/8237), [\#8244](https://github.com/matrix-org/synapse/issues/8244))
157 - Return the previous stream token if a non-member event is a duplicate. ([\#8093](https://github.com/matrix-org/synapse/issues/8093), [\#8112](https://github.com/matrix-org/synapse/issues/8112))
158 - Separate `get_current_token` into two functions, since there are two different use cases for it. ([\#8113](https://github.com/matrix-org/synapse/issues/8113))
159 - Remove `ChainedIdGenerator`. ([\#8123](https://github.com/matrix-org/synapse/issues/8123))
160 - Reduce the amount of whitespace in JSON stored and sent in responses. ([\#8124](https://github.com/matrix-org/synapse/issues/8124))
161 - Update the test federation client to handle streaming responses. ([\#8130](https://github.com/matrix-org/synapse/issues/8130))
162 - Micro-optimisations to `get_auth_chain_ids`. ([\#8132](https://github.com/matrix-org/synapse/issues/8132))
163 - Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface. ([\#8161](https://github.com/matrix-org/synapse/issues/8161))
164 - Add functions to `MultiWriterIdGen` used by events stream. ([\#8164](https://github.com/matrix-org/synapse/issues/8164), [\#8179](https://github.com/matrix-org/synapse/issues/8179))
165 - Fix tests that were broken due to the merge of 1.19.1. ([\#8167](https://github.com/matrix-org/synapse/issues/8167))
166 - Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`. ([\#8171](https://github.com/matrix-org/synapse/issues/8171))
167 - Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`. ([\#8174](https://github.com/matrix-org/synapse/issues/8174), [\#8181](https://github.com/matrix-org/synapse/issues/8181))
168 - Standardize the mypy configuration. ([\#8175](https://github.com/matrix-org/synapse/issues/8175))
169 - Refactor some of `LoginRestServlet`'s helper methods, and move them to `AuthHandler` for easier reuse. ([\#8182](https://github.com/matrix-org/synapse/issues/8182))
170 - Fix `wait_for_stream_position` to allow multiple waiters on same stream ID. ([\#8196](https://github.com/matrix-org/synapse/issues/8196))
171 - Make `MultiWriterIDGenerator` work for streams that use negative values. ([\#8203](https://github.com/matrix-org/synapse/issues/8203))
172 - Refactor queries for device keys and cross-signatures. ([\#8204](https://github.com/matrix-org/synapse/issues/8204), [\#8205](https://github.com/matrix-org/synapse/issues/8205), [\#8222](https://github.com/matrix-org/synapse/issues/8222), [\#8224](https://github.com/matrix-org/synapse/issues/8224), [\#8225](https://github.com/matrix-org/synapse/issues/8225), [\#8231](https://github.com/matrix-org/synapse/issues/8231), [\#8233](https://github.com/matrix-org/synapse/issues/8233), [\#8234](https://github.com/matrix-org/synapse/issues/8234))
173 - Fix type hints for functions decorated with `@cached`. ([\#8240](https://github.com/matrix-org/synapse/issues/8240))
174 - Remove obsolete `order` field from federation send queues. ([\#8245](https://github.com/matrix-org/synapse/issues/8245))
175 - Stop sub-classing from object. ([\#8249](https://github.com/matrix-org/synapse/issues/8249))
176 - Add more logging to debug slow startup. ([\#8264](https://github.com/matrix-org/synapse/issues/8264))
177 - Do not attempt to upgrade database schema on worker processes. ([\#8266](https://github.com/matrix-org/synapse/issues/8266), [\#8276](https://github.com/matrix-org/synapse/issues/8276))
178
179
0180 Synapse 1.19.1 (2020-08-27)
1181 ===========================
2182
2323 from twisted.web.http_headers import Headers
2424
2525
26 class HttpClient(object):
26 class HttpClient:
2727 """ Interface for talking json over http
2828 """
2929
168168 return d
169169
170170
171 class _RawProducer(object):
171 class _RawProducer:
172172 def __init__(self, data):
173173 self.data = data
174174 self.body = data
185185 pass
186186
187187
188 class _JsonProducer(object):
188 class _JsonProducer:
189189 """ Used by the twisted http client to create the HTTP body from json
190190 """
191191
140140 curses.endwin()
141141
142142
143 class Callback(object):
143 class Callback:
144144 def __init__(self, stdio):
145145 self.stdio = stdio
146146
5454 logging.exception(failure)
5555
5656
57 class InputOutput(object):
57 class InputOutput:
5858 """ This is responsible for basic I/O so that a user can interact with
5959 the example app.
6060 """
131131 self.io.print_log(msg)
132132
133133
134 class Room(object):
134 class Room:
135135 """ Used to store (in memory) the current membership state of a room, and
136136 which home servers we should send PDUs associated with the room to.
137137 """
0 matrix-synapse (1.20.1-1) unstable; urgency=medium
1
2 * New upstream release.
3
4 -- Andrej Shadura <andrewsh@debian.org> Fri, 25 Sep 2020 18:06:35 +0200
5
6 matrix-synapse (1.20.0-1) unstable; urgency=medium
7
8 * New upstream release.
9 * Bump python3-canonicaljson version.
10
11 -- Andrej Shadura <andrewsh@debian.org> Wed, 23 Sep 2020 10:03:06 +0200
12
13 matrix-synapse (1.19.3-1) unstable; urgency=high
14
15 * New upstream release.
16
17 -- Andrej Shadura <andrewsh@debian.org> Sat, 19 Sep 2020 14:22:15 +0300
18
19 matrix-synapse (1.19.2-1) unstable; urgency=high
20
21 * New upstream release.
22
23 -- Andrej Shadura <andrewsh@debian.org> Wed, 16 Sep 2020 17:44:26 +0300
24
025 matrix-synapse (1.19.1-1~bpo10+3) buster-backports; urgency=high
126
227 * Backport an upstream fix:
1212 python3-bcrypt,
1313 python3-bleach (>= 1.4.2),
1414 python3-blist,
15 python3-canonicaljson (>= 1.2.0),
15 python3-canonicaljson (>= 1.3.0~),
1616 python3-daemonize,
1717 python3-frozendict (>= 1),
1818 python3-idna,
1818 FROM docker.io/python:${PYTHON_VERSION}-slim as builder
1919
2020 # install the OS build deps
21
22
2321 RUN apt-get update && apt-get install -y \
2422 build-essential \
23 libffi-dev \
24 libjpeg-dev \
2525 libpq-dev \
26 libssl-dev \
27 libwebp-dev \
28 libxml++2.6-dev \
29 libxslt1-dev \
30 zlib1g-dev \
2631 && rm -rf /var/lib/apt/lists/*
2732
2833 # Build dependencies that are not available as wheels, to speed up rebuilds
5459 FROM docker.io/python:${PYTHON_VERSION}-slim
5560
5661 RUN apt-get update && apt-get install -y \
62 curl \
63 gosu \
64 libjpeg62-turbo \
5765 libpq5 \
66 libwebp6 \
5867 xmlsec1 \
59 gosu \
6068 && rm -rf /var/lib/apt/lists/*
6169
6270 COPY --from=builder /install /usr/local
6876 EXPOSE 8008/tcp 8009/tcp 8448/tcp
6977
7078 ENTRYPOINT ["/start.py"]
79
80 HEALTHCHECK --interval=1m --timeout=5s \
81 CMD curl -fSs http://localhost:8008/health || exit 1
161161
162162 You can choose to build a different docker image by changing the value of the `-f` flag to
163163 point to another Dockerfile.
164
165 ## Disabling the healthcheck
166
167 If you are using a non-standard port or TLS inside Docker, you can disable the healthcheck
168 when running the above `docker run` commands by adding the following flag:
169
170 ```
171 --no-healthcheck
172 ```
173 ## Setting a custom healthcheck on docker run
174
175 If you wish to point the healthcheck at a different port with the `docker run` command, add the following:
176
177 ```
178 --health-cmd 'curl -fSs http://localhost:1234/health'
179 ```
180
181 ## Setting the healthcheck in docker-compose file
182
183 You can add the following to set a custom healthcheck in a docker-compose file.
184 You will need compose file format version 2.1 or later for this to work.
185
186 ```
187 healthcheck:
188 test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"]
189 interval: 1m
190 timeout: 10s
191 retries: 3
192 ```
107107
108108 GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
109109
110 To use it, you will need to authenticate by providing an `access_token` for a
110 To use it, you will need to authenticate by providing an ``access_token`` for a
111111 server admin: see `README.rst <README.rst>`_.
112112
113113 The parameter ``from`` is optional but used for pagination, denoting the
118118 The parameter ``limit`` is optional but is used for pagination, denoting the
119119 maximum number of items to return in this call. Defaults to ``100``.
120120
121 The parameter ``user_id`` is optional and filters to only users with user IDs
122 that contain this value.
121 The parameter ``user_id`` is optional and filters to only return users with user IDs
122 that contain this value. This parameter is ignored when using the ``name`` parameter.
123
124 The parameter ``name`` is optional and filters to only return users with user ID localparts
125 **or** displaynames that contain this value.
123126
124127 The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
125128 Defaults to ``true`` to include guest users.
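A minimal example of calling this API from Python with the new ``name`` filter; the homeserver URL, token and filter value below are placeholders, and the access token must belong to a server admin::

    import requests

    ADMIN_TOKEN = "<admin access token>"

    resp = requests.get(
        "https://homeserver.example.com/_synapse/admin/v2/users",
        params={"from": 0, "limit": 10, "guests": "false", "name": "alice"},
        headers={"Authorization": "Bearer " + ADMIN_TOKEN},
    )
    resp.raise_for_status()
    for user in resp.json()["users"]:
        print(user["name"])  # each entry's "name" is the full user ID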
210213
211214 This API deactivates an account. It removes active access tokens, resets the
212215 password, and deletes third-party IDs (to prevent the user requesting a
213 password reset). It can also mark the user as GDPR-erased (stopping their data
214 from distributed further, and deleting it entirely if there are no other
215 references to it).
216 password reset).
217
218 It can also mark the user as GDPR-erased. This means messages sent by the
219 user will still be visible to anyone who was in the room when these messages
220 were sent, but hidden from users joining the room afterwards.
216221
217222 The api is::
218223
4646 proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
4747 configure a reverse proxy.
4848
49 ### Known issues
50
51 **HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features
52 in the HTTP library used by Synapse, 308 redirects are currently not followed by
53 federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This
54 may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`),
55 and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect
56 codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871)
57 might be helpful. For other users, switching to a `301 Moved Permanently` code may be
58 an option. 308 redirect codes will be supported properly in a future
59 release of Synapse.
60
4961 ## Running a demo federation of Synapses
5062
5163 If you want to get up and running quickly with a trio of homeservers in a
1313
1414 * [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
1515 * [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
16 * [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
1617
1718 ## Required methods
1819
377377 # min_lifetime: 1d
378378 # max_lifetime: 1y
379379
380 # Retention policy limits. If set, a user won't be able to send a
381 # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
382 # that's not within this range. This is especially useful in closed federations,
383 # in which server admins can make sure every federating server applies the same
384 # rules.
380 # Retention policy limits. If set, and a room's state contains an
381 # 'm.room.retention' event with a 'min_lifetime' or a 'max_lifetime' that's
382 # out of these bounds, Synapse will cap the room's policy to these limits
383 # when running purge jobs.
385384 #
386385 #allowed_lifetime_min: 1d
387386 #allowed_lifetime_max: 1y
407406 # (e.g. every 12h), but not want that purge to be performed by a job that's
408407 # iterating over every room it knows, which could be heavy on the server.
409408 #
409 # If any purge job is configured, it is strongly recommended to have at least
410 # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
411 # set, or one job without 'shortest_max_lifetime' and one job without
412 # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
413 # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
414 # room's policy to these values is done after the policies are retrieved from
415 # Synapse's database (which is done using the range specified in a purge job's
416 # configuration).
417 #
410418 #purge_jobs:
411 # - shortest_max_lifetime: 1d
412 # longest_max_lifetime: 3d
419 # - longest_max_lifetime: 3d
413420 # interval: 12h
414421 # - shortest_max_lifetime: 3d
415 # longest_max_lifetime: 1y
416422 # interval: 1d
417423
418424 # Inhibits the /requestToken endpoints from returning an error that might leak
20012007 # Directory in which Synapse will try to find the template files below.
20022008 # If not set, default templates from within the Synapse package will be used.
20032009 #
2004 # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
2005 # If you *do* uncomment it, you will need to make sure that all the templates
2006 # below are in the directory.
2010 # Do not uncomment this setting unless you want to customise the templates.
20072011 #
20082012 # Synapse will look for the following templates in this directory:
20092013 #
00 [Unit]
11 Description=Synapse %i
22 AssertPathExists=/etc/matrix-synapse/workers/%i.yaml
3
34 # This service should be restarted when the synapse target is restarted.
45 PartOf=matrix-synapse.target
6
7 # if this is started at the same time as the main process, let the main process
8 # start first, to initialise the database schema.
9 After=matrix-synapse.service
510
611 [Service]
712 Type=notify
00 [mypy]
11 namespace_packages = True
2 plugins = mypy_zope:plugin
2 plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
33 follow_imports = silent
44 check_untyped_defs = True
55 show_error_codes = True
66 show_traceback = True
77 mypy_path = stubs
8 files =
9 synapse/api,
10 synapse/appservice,
11 synapse/config,
12 synapse/event_auth.py,
13 synapse/events/builder.py,
14 synapse/events/spamcheck.py,
15 synapse/federation,
16 synapse/handlers/auth.py,
17 synapse/handlers/cas_handler.py,
18 synapse/handlers/directory.py,
19 synapse/handlers/events.py,
20 synapse/handlers/federation.py,
21 synapse/handlers/identity.py,
22 synapse/handlers/initial_sync.py,
23 synapse/handlers/message.py,
24 synapse/handlers/oidc_handler.py,
25 synapse/handlers/pagination.py,
26 synapse/handlers/presence.py,
27 synapse/handlers/room.py,
28 synapse/handlers/room_member.py,
29 synapse/handlers/room_member_worker.py,
30 synapse/handlers/saml_handler.py,
31 synapse/handlers/sync.py,
32 synapse/handlers/ui_auth,
33 synapse/http/federation/well_known_resolver.py,
34 synapse/http/server.py,
35 synapse/http/site.py,
36 synapse/logging/,
37 synapse/metrics,
38 synapse/module_api,
39 synapse/notifier.py,
40 synapse/push/pusherpool.py,
41 synapse/push/push_rule_evaluator.py,
42 synapse/replication,
43 synapse/rest,
44 synapse/server.py,
45 synapse/server_notices,
46 synapse/spam_checker_api,
47 synapse/state,
48 synapse/storage/databases/main/stream.py,
49 synapse/storage/databases/main/ui_auth.py,
50 synapse/storage/database.py,
51 synapse/storage/engines,
52 synapse/storage/state.py,
53 synapse/storage/util,
54 synapse/streams,
55 synapse/types.py,
56 synapse/util/caches/descriptors.py,
57 synapse/util/caches/stream_change_cache.py,
58 synapse/util/metrics.py,
59 tests/replication,
60 tests/test_utils,
61 tests/rest/client/v2_alpha/test_auth.py,
62 tests/util/test_stream_change_cache.py
863
964 [mypy-pymacaroons.*]
1065 ignore_missing_imports = True
8888 "redactions": ["have_censored"],
8989 "room_stats_state": ["is_federatable"],
9090 "local_media_repository": ["safe_from_quarantine"],
91 "users": ["shadow_banned"],
9192 }
9293
9394
2020 import base64
2121 import json
2222 import sys
23 from typing import Any, Optional
2324 from urllib import parse as urlparse
2425
2526 import nacl.signing
2627 import requests
28 import signedjson.types
2729 import srvlookup
2830 import yaml
2931 from requests.adapters import HTTPAdapter
6870 ).encode("UTF-8")
6971
7072
71 def sign_json(json_object, signing_key, signing_name):
73 def sign_json(
74 json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str
75 ) -> Any:
7276 signatures = json_object.pop("signatures", {})
7377 unsigned = json_object.pop("unsigned", None)
7478
121125 return keys
122126
123127
124 def request_json(method, origin_name, origin_key, destination, path, content):
128 def request(
129 method: Optional[str],
130 origin_name: str,
131 origin_key: signedjson.types.SigningKey,
132 destination: str,
133 path: str,
134 content: Optional[str],
135 ) -> requests.Response:
125136 if method is None:
126137 if content is None:
127138 method = "GET"
158169 if method == "POST":
159170 headers["Content-Type"] = "application/json"
160171
161 result = s.request(
162 method=method, url=dest, headers=headers, verify=False, data=content
163 )
164 sys.stderr.write("Status Code: %d\n" % (result.status_code,))
165 return result.json()
172 return s.request(
173 method=method,
174 url=dest,
175 headers=headers,
176 verify=False,
177 data=content,
178 stream=True,
179 )
166180
167181
168182 def main():
221235 with open(args.signing_key_path) as f:
222236 key = read_signing_keys(f)[0]
223237
224 result = request_json(
238 result = request(
225239 args.method,
226240 args.server_name,
227241 key,
230244 content=args.body,
231245 )
232246
233 json.dump(result, sys.stdout)
247 sys.stderr.write("Status Code: %d\n" % (result.status_code,))
248
249 for chunk in result.iter_content():
250 # we write raw utf8 to stdout.
251 sys.stdout.buffer.write(chunk)
252
234253 print("")
235254
236255
1414 from synapse.storage.signatures import SignatureStore
1515
1616
17 class Store(object):
17 class Store:
1818 _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
1919 _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
2020 _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """This is a mypy plugin for Synpase to deal with some of the funky typing that
16 can crop up, e.g the cache descriptors.
17 """
18
19 from typing import Callable, Optional
20
21 from mypy.plugin import MethodSigContext, Plugin
22 from mypy.typeops import bind_self
23 from mypy.types import CallableType
24
25
26 class SynapsePlugin(Plugin):
27 def get_method_signature_hook(
28 self, fullname: str
29 ) -> Optional[Callable[[MethodSigContext], CallableType]]:
30 if fullname.startswith(
31 "synapse.util.caches.descriptors._CachedFunction.__call__"
32 ):
33 return cached_function_method_signature
34 return None
35
36
37 def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
38 """Fixes the `_CachedFunction.__call__` signature to be correct.
39
40 It already has *almost* the correct signature, except:
41
42 1. the `self` argument needs to be marked as "bound"; and
43 2. any `cache_context` argument should be removed.
44 """
45
46 # First we mark this as a bound function signature.
47 signature = bind_self(ctx.default_signature)
48
49 # Secondly, we remove any "cache_context" args.
50 #
51 # Note: We should only be doing this if `cache_context=True` is set, but if
52 # it isn't then the code will raise an exception when it's called anyway, so
53 # it's not the end of the world.
54 context_arg_index = None
55 for idx, name in enumerate(signature.arg_names):
56 if name == "cache_context":
57 context_arg_index = idx
58 break
59
60 # Compare against None explicitly: an index of 0 is falsy.
61 if context_arg_index is not None:
61 arg_types = list(signature.arg_types)
62 arg_types.pop(context_arg_index)
63
64 arg_names = list(signature.arg_names)
65 arg_names.pop(context_arg_index)
66
67 arg_kinds = list(signature.arg_kinds)
68 arg_kinds.pop(context_arg_index)
69
70 signature = signature.copy_modified(
71 arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
72 )
73
74 return signature
75
76
77 def plugin(version: str):
78 # This is the entry point of the plugin, and lets us deal with the fact
79 # that the mypy plugin interface is *not* stable by looking at the version
80 # string.
81 #
82 # However, since we pin the version of mypy Synapse uses in CI, we don't
83 # really care.
84 return SynapsePlugin
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 # Stub for frozendict.
16
17 from typing import (
18 Any,
19 Hashable,
20 Iterable,
21 Iterator,
22 Mapping,
23 overload,
24 Tuple,
25 TypeVar,
26 )
27
28 _KT = TypeVar("_KT", bound=Hashable) # Key type.
29 _VT = TypeVar("_VT") # Value type.
30
31 class frozendict(Mapping[_KT, _VT]):
32 @overload
33 def __init__(self, **kwargs: _VT) -> None: ...
34 @overload
35 def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
36 @overload
37 def __init__(
38 self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
39 ) -> None: ...
40 def __getitem__(self, key: _KT) -> _VT: ...
41 def __contains__(self, key: Any) -> bool: ...
42 def copy(self, **add_or_replace: Any) -> frozendict: ...
43 def __iter__(self) -> Iterator[_KT]: ...
44 def __len__(self) -> int: ...
45 def __repr__(self) -> str: ...
46 def __hash__(self) -> int: ...
4747 except ImportError:
4848 pass
4949
50 __version__ = "1.19.1"
50 __version__ = "1.20.1"
5151
5252 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
5353 # We import here so that we don't have to install a bunch of deps when
5757 pass
5858
5959
60 class Auth(object):
60 class Auth:
6161 """
6262 FIXME: This class contains a mix of functions for authenticating users
6363 of our client-server API and authenticating events added to room graphs.
212212 user = user_info["user"]
213213 token_id = user_info["token_id"]
214214 is_guest = user_info["is_guest"]
215 shadow_banned = user_info["shadow_banned"]
215216
216217 # Deny the request if the user account has expired.
217218 if self._account_validity.enabled and not allow_expired:
251252 opentracing.set_tag("device_id", device_id)
252253
253254 return synapse.types.create_requester(
254 user, token_id, is_guest, device_id, app_service=app_service
255 user,
256 token_id,
257 is_guest,
258 shadow_banned,
259 device_id,
260 app_service=app_service,
255261 )
256262 except KeyError:
257263 raise MissingClientTokenError()
296302 dict that includes:
297303 `user` (UserID)
298304 `is_guest` (bool)
305 `shadow_banned` (bool)
299306 `token_id` (int|None): access token id. May be None if guest
300307 `device_id` (str|None): device corresponding to access token
301308 Raises:
355362 ret = {
356363 "user": user,
357364 "is_guest": True,
365 "shadow_banned": False,
358366 "token_id": None,
359367 # all guests get the same device id
360368 "device_id": GUEST_DEVICE_ID,
364372 ret = {
365373 "user": user,
366374 "is_guest": False,
375 "shadow_banned": False,
367376 "token_id": None,
368377 "device_id": None,
369378 }
487496 "user": UserID.from_string(ret.get("name")),
488497 "token_id": ret.get("token_id", None),
489498 "is_guest": False,
499 "shadow_banned": ret.get("shadow_banned"),
490500 "device_id": ret.get("device_id"),
491501 "valid_until_ms": ret.get("valid_until_ms"),
492502 }
2121 logger = logging.getLogger(__name__)
2222
2323
24 class AuthBlocking(object):
24 class AuthBlocking:
2525 def __init__(self, hs):
2626 self.store = hs.get_datastore()
2727
2727 MAX_USERID_LENGTH = 255
2828
2929
30 class Membership(object):
30 class Membership:
3131
3232 """Represents the membership states of a user in a room."""
3333
3939 LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
4040
4141
42 class PresenceState(object):
42 class PresenceState:
4343 """Represents the presence state of a user."""
4444
4545 OFFLINE = "offline"
4747 ONLINE = "online"
4848
4949
50 class JoinRules(object):
50 class JoinRules:
5151 PUBLIC = "public"
5252 KNOCK = "knock"
5353 INVITE = "invite"
5454 PRIVATE = "private"
5555
5656
57 class LoginType(object):
57 class LoginType:
5858 PASSWORD = "m.login.password"
5959 EMAIL_IDENTITY = "m.login.email.identity"
6060 MSISDN = "m.login.msisdn"
6464 DUMMY = "m.login.dummy"
6565
6666
67 class EventTypes(object):
67 class EventTypes:
6868 Member = "m.room.member"
6969 Create = "m.room.create"
7070 Tombstone = "m.room.tombstone"
9595 Presence = "m.presence"
9696
9797
98 class RejectedReason(object):
98 class RejectedReason:
9999 AUTH_ERROR = "auth_error"
100100
101101
102 class RoomCreationPreset(object):
102 class RoomCreationPreset:
103103 PRIVATE_CHAT = "private_chat"
104104 PUBLIC_CHAT = "public_chat"
105105 TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
106106
107107
108 class ThirdPartyEntityKind(object):
108 class ThirdPartyEntityKind:
109109 USER = "user"
110110 LOCATION = "location"
111111
114114 ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
115115
116116
117 class UserTypes(object):
117 class UserTypes:
118118 """Allows for user type specific behaviour. With the benefit of hindsight
119119 'admin' and 'guest' users should also be UserTypes. Normal users are type None
120120 """
124124 ALL_USER_TYPES = (SUPPORT, BOT)
125125
126126
127 class RelationTypes(object):
127 class RelationTypes:
128128 """The types of relations known to this server.
129129 """
130130
133133 REFERENCE = "m.reference"
134134
135135
136 class LimitBlockingTypes(object):
136 class LimitBlockingTypes:
137137 """Reasons that a server may be blocked"""
138138
139139 MONTHLY_ACTIVE_USER = "monthly_active_user"
140140 HS_DISABLED = "hs_disabled"
141141
142142
143 class EventContentFields(object):
143 class EventContentFields:
144144 """Fields found in events' content, regardless of type."""
145145
146146 # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
151151 SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
152152
153153
154 class RoomEncryptionAlgorithms(object):
154 class RoomEncryptionAlgorithms:
155155 MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
156156 DEFAULT = MEGOLM_V1_AES_SHA2
2020 from http import HTTPStatus
2121 from typing import Dict, List, Optional, Union
2222
23 from canonicaljson import json
24
2523 from twisted.web import http
24
25 from synapse.util import json_decoder
2626
2727 if typing.TYPE_CHECKING:
2828 from synapse.types import JsonDict
3030 logger = logging.getLogger(__name__)
3131
3232
33 class Codes(object):
33 class Codes:
3434 UNRECOGNIZED = "M_UNRECOGNIZED"
3535 UNAUTHORIZED = "M_UNAUTHORIZED"
3636 FORBIDDEN = "M_FORBIDDEN"
592592 # try to parse the body as json, to get better errcode/msg, but
593593 # default to M_UNKNOWN with the HTTP status as the error text
594594 try:
595 j = json.loads(self.response.decode("utf-8"))
595 j = json_decoder.decode(self.response.decode("utf-8"))
596596 except ValueError:
597597 j = {}
598598
603603 errmsg = j.pop("error", self.msg)
604604
605605 return ProxiedRequestError(self.code, errmsg, errcode, j)
606
607
608 class ShadowBanError(Exception):
609 """
610 Raised when a shadow-banned user attempts to perform an action.
611
612 This should be caught and a proper "fake" success response sent to the user.
613 """
2222
2323 from synapse.api.constants import EventContentFields
2424 from synapse.api.errors import SynapseError
25 from synapse.storage.presence import UserPresenceState
25 from synapse.api.presence import UserPresenceState
2626 from synapse.types import RoomID, UserID
2727
2828 FILTER_SCHEMA = {
129129 return UserID.from_string(user_id_str)
130130
131131
132 class Filtering(object):
132 class Filtering:
133133 def __init__(self, hs):
134134 super(Filtering, self).__init__()
135135 self.store = hs.get_datastore()
167167 raise SynapseError(400, str(e))
168168
169169
170 class FilterCollection(object):
170 class FilterCollection:
171171 def __init__(self, filter_json):
172172 self._filter_json = filter_json
173173
248248 )
249249
250250
251 class Filter(object):
251 class Filter:
252252 def __init__(self, filter_json):
253253 self.filter_json = filter_json
254254
0 # -*- coding: utf-8 -*-
1 # Copyright 2014-2016 OpenMarket Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from collections import namedtuple
16
17 from synapse.api.constants import PresenceState
18
19
20 class UserPresenceState(
21 namedtuple(
22 "UserPresenceState",
23 (
24 "user_id",
25 "state",
26 "last_active_ts",
27 "last_federation_update_ts",
28 "last_user_sync_ts",
29 "status_msg",
30 "currently_active",
31 ),
32 )
33 ):
34 """Represents the current presence state of the user.
35
36 user_id (str)
37 last_active (int): Time in msec that the user last interacted with server.
38 last_federation_update (int): Time in msec since either a) we sent a presence
39 update to other servers or b) we received a presence update, depending
40 on if is a local user or not.
41 last_user_sync (int): Time in msec that the user last *completed* a sync
42 (or event stream).
43 status_msg (str): User set status message.
44 """
45
46 def as_dict(self):
47 return dict(self._asdict())
48
49 @staticmethod
50 def from_dict(d):
51 return UserPresenceState(**d)
52
53 def copy_and_replace(self, **kwargs):
54 return self._replace(**kwargs)
55
56 @classmethod
57 def default(cls, user_id):
58 """Returns a default presence state.
59 """
60 return cls(
61 user_id=user_id,
62 state=PresenceState.OFFLINE,
63 last_active_ts=0,
64 last_federation_update_ts=0,
65 last_user_sync_ts=0,
66 status_msg=None,
67 currently_active=False,
68 )
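# A short usage sketch (not part of the module): building and evolving a
# presence state, then round-tripping it through its dict form.
state = UserPresenceState.default("@alice:example.com")
state = state.copy_and_replace(
    state=PresenceState.ONLINE, last_active_ts=1_600_000_000_000
)
wire = state.as_dict()  # e.g. for replication or persistence
assert UserPresenceState.from_dict(wire) == state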
2020 from synapse.util import Clock
2121
2222
23 class Ratelimiter(object):
23 class Ratelimiter:
2424 """
2525 Ratelimit actions marked by arbitrary keys.
2626
1717 import attr
1818
1919
20 class EventFormatVersions(object):
20 class EventFormatVersions:
2121 """This is an internal enum for tracking the version of the event format,
2222 independently from the room version.
2323 """
3434 }
3535
3636
37 class StateResolutionVersions(object):
37 class StateResolutionVersions:
3838 """Enum to identify the state resolution algorithms"""
3939
4040 V1 = 1 # room v1 state res
4141 V2 = 2 # MSC1442 state res: room v2 and later
4242
4343
44 class RoomDisposition(object):
44 class RoomDisposition:
4545 STABLE = "stable"
4646 UNSTABLE = "unstable"
4747
4848
4949 @attr.s(slots=True, frozen=True)
50 class RoomVersion(object):
50 class RoomVersion:
5151 """An object which describes the unique attributes of a room version."""
5252
5353 identifier = attr.ib() # str; the identifier for this version
6868 limit_notifications_power_levels = attr.ib(type=bool)
6969
7070
71 class RoomVersions(object):
71 class RoomVersions:
7272 V1 = RoomVersion(
7373 "1",
7474 RoomDisposition.STABLE,
3232 LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
3333
3434
35 class ConsentURIBuilder(object):
35 class ConsentURIBuilder:
3636 def __init__(self, hs_config):
3737 """
3838 Args:
333333 This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
334334 can run out of file descriptors and infinite loop if we attempt to do too
335335 many DNS queries at once
336
337 XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
338 you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
339 backed by the reactor's default threadpool (which is limited to 10 threads). So
340 (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
341 understand why we would run out of FDs if we did too many lookups at once.
342 -- richvdh 2020/08/29
336343 """
337344 new_resolver = _LimitedHostnameResolver(
338345 reactor.nameResolver, max_dns_requests_in_flight
341348 reactor.installNameResolver(new_resolver)
342349
343350
344 class _LimitedHostnameResolver(object):
351 class _LimitedHostnameResolver:
345352 """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
346353 """
347354
401408 yield deferred
402409
403410
404 class _DeferredResolutionReceiver(object):
411 class _DeferredResolutionReceiver:
405412 """Wraps a IResolutionReceiver and simply resolves the given deferred when
406413 resolution is complete
407414 """
7878 pass
7979
8080
81 @defer.inlineCallbacks
82 def export_data_command(hs, args):
81 async def export_data_command(hs, args):
8382 """Export data for a user.
8483
8584 Args:
9089 user_id = args.user_id
9190 directory = args.output_directory
9291
93 res = yield defer.ensureDeferred(
94 hs.get_handlers().admin_handler.export_user_data(
95 user_id, FileExfiltrationWriter(user_id, directory=directory)
96 )
92 res = await hs.get_handlers().admin_handler.export_user_data(
93 user_id, FileExfiltrationWriter(user_id, directory=directory)
9794 )
9895 print(res)
9996
231228 # We also make sure that `_base.start` gets run before we actually run the
232229 # command.
233230
234 @defer.inlineCallbacks
235 def run(_reactor):
231 async def run():
236232 with LoggingContext("command"):
237 yield _base.start(ss, [])
238 yield args.func(ss, args)
233 _base.start(ss, [])
234 await args.func(ss, args)
239235
240236 _base.start_worker_reactor(
241 "synapse-admin-cmd", config, run_command=lambda: task.react(run)
237 "synapse-admin-cmd",
238 config,
239 run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
242240 )
243241
244242
744744 self.send_handler.wake_destination(server)
745745
746746
747 class FederationSenderHandler(object):
747 class FederationSenderHandler:
748748 """Processes the fedration replication stream
749749
750750 This class is only instantiate on the worker responsible for sending outbound
410410
411411 return provision
412412
413 @defer.inlineCallbacks
414 def reprovision_acme():
413 async def reprovision_acme():
415414 """
416415 Provision a certificate from ACME, if required, and reload the TLS
417416 certificate if it's renewed.
418417 """
419 reprovisioned = yield defer.ensureDeferred(do_acme())
418 reprovisioned = await do_acme()
420419 if reprovisioned:
421420 _base.refresh_certificate(hs)
422421
423 @defer.inlineCallbacks
424 def start():
422 async def start():
425423 try:
426424 # Run the ACME provisioning code, if it's enabled.
427425 if hs.config.acme_enabled:
428426 acme = hs.get_acme_handler()
429427 # Start up the webservices which we will respond to ACME
430428 # challenges with, and then provision.
431 yield defer.ensureDeferred(acme.start_listening())
432 yield defer.ensureDeferred(do_acme())
429 await acme.start_listening()
430 await do_acme()
433431
434432 # Check if it needs to be reprovisioned every day.
435433 hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
438436 if hs.config.oidc_enabled:
439437 oidc = hs.get_oidc_handler()
440438 # Loading the provider metadata also ensures the provider config is valid.
441 yield defer.ensureDeferred(oidc.load_metadata())
442 yield defer.ensureDeferred(oidc.load_jwks())
439 await oidc.load_metadata()
440 await oidc.load_jwks()
443441
444442 _base.start(hs, config.listeners)
445443
455453 reactor.stop()
456454 sys.exit(1)
457455
458 reactor.callWhenRunning(start)
456 reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
459457
460458 return hs
461459
1313 # limitations under the License.
1414 import logging
1515 import re
16 from typing import TYPE_CHECKING
1617
1718 from synapse.api.constants import EventTypes
19 from synapse.appservice.api import ApplicationServiceApi
1820 from synapse.types import GroupID, get_domain_from_id
1921 from synapse.util.caches.descriptors import cached
2022
23 if TYPE_CHECKING:
24 from synapse.storage.databases.main import DataStore
25
2126 logger = logging.getLogger(__name__)
2227
2328
24 class ApplicationServiceState(object):
29 class ApplicationServiceState:
2530 DOWN = "down"
2631 UP = "up"
2732
2833
29 class AppServiceTransaction(object):
34 class AppServiceTransaction:
3035 """Represents an application service transaction."""
3136
3237 def __init__(self, service, id, events):
3439 self.id = id
3540 self.events = events
3641
37 def send(self, as_api):
42 async def send(self, as_api: ApplicationServiceApi) -> bool:
3843 """Sends this transaction using the provided AS API interface.
3944
4045 Args:
41 as_api(ApplicationServiceApi): The API to use to send.
46 as_api: The API to use to send.
4247 Returns:
43 An Awaitable which resolves to True if the transaction was sent.
44 """
45 return as_api.push_bulk(
48 True if the transaction was sent.
49 """
50 return await as_api.push_bulk(
4651 service=self.service, events=self.events, txn_id=self.id
4752 )
4853
49 def complete(self, store):
54 async def complete(self, store: "DataStore") -> None:
5055 """Completes this transaction as successful.
5156
5257 Marks this transaction ID on the application service and removes the
5459
5560 Args:
5661 store: The database store to operate on.
57 Returns:
58 A Deferred which resolves to True if the transaction was completed.
59 """
60 return store.complete_appservice_txn(service=self.service, txn_id=self.id)
61
62
63 class ApplicationService(object):
62 """
63 await store.complete_appservice_txn(service=self.service, txn_id=self.id)
64
65
66 class ApplicationService:
6467 """Defines an application service. This definition is mostly what is
6568 provided to the /register AS API.
6669
1313 # limitations under the License.
1414 import logging
1515 import urllib
16 from typing import TYPE_CHECKING, Optional
1617
1718 from prometheus_client import Counter
18
19 from twisted.internet import defer
2019
2120 from synapse.api.constants import EventTypes, ThirdPartyEntityKind
2221 from synapse.api.errors import CodeMessageException
2322 from synapse.events.utils import serialize_event
2423 from synapse.http.client import SimpleHttpClient
25 from synapse.types import ThirdPartyInstanceID
24 from synapse.types import JsonDict, ThirdPartyInstanceID
2625 from synapse.util.caches.response_cache import ResponseCache
26
27 if TYPE_CHECKING:
28 from synapse.appservice import ApplicationService
2729
2830 logger = logging.getLogger(__name__)
2931
162164 logger.warning("query_3pe to %s threw exception %s", uri, ex)
163165 return []
164166
165 def get_3pe_protocol(self, service, protocol):
167 async def get_3pe_protocol(
168 self, service: "ApplicationService", protocol: str
169 ) -> Optional[JsonDict]:
166170 if service.url is None:
167171 return {}
168172
169 @defer.inlineCallbacks
170 def _get():
173 async def _get() -> Optional[JsonDict]:
171174 uri = "%s%s/thirdparty/protocol/%s" % (
172175 service.url,
173176 APP_SERVICE_PREFIX,
174177 urllib.parse.quote(protocol),
175178 )
176179 try:
177 info = yield defer.ensureDeferred(self.get_json(uri, {}))
180 info = await self.get_json(uri, {})
178181
179182 if not _is_valid_3pe_metadata(info):
180183 logger.warning(
195198 return None
196199
197200 key = (service.id, protocol)
198 return self.protocol_meta_cache.wrap(key, _get)
201 return await self.protocol_meta_cache.wrap(key, _get)
199202
200203 async def push_bulk(self, service, events, txn_id=None):
201204 if service.url is None:
5656 logger = logging.getLogger(__name__)
5757
5858
59 class ApplicationServiceScheduler(object):
59 class ApplicationServiceScheduler:
6060 """ Public facing API for this module. Does the required DI to tie the
6161 components together. This also serves as the "event_pool", which in this
6262 case is a simple array.
8585 self.queuer.enqueue(service, event)
8686
8787
88 class _ServiceQueuer(object):
88 class _ServiceQueuer:
8989 """Queue of events waiting to be sent to appservices.
9090
9191 Groups events into transactions per-appservice, and sends them on to the
132132 self.requests_in_flight.discard(service.id)
133133
134134
135 class _TransactionController(object):
135 class _TransactionController:
136136 """Transaction manager.
137137
138138 Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer
208208 return state == ApplicationServiceState.UP or state is None
209209
210210
211 class _Recoverer(object):
211 class _Recoverer:
212212 """Manages retries and backoff for a DOWN appservice.
213213
214214 We have one of these for each appservice which is currently considered DOWN.
1717 import argparse
1818 import errno
1919 import os
20 import time
21 import urllib.parse
2022 from collections import OrderedDict
2123 from hashlib import sha256
2224 from textwrap import dedent
23 from typing import Any, List, MutableMapping, Optional
25 from typing import Any, Callable, List, MutableMapping, Optional
2426
2527 import attr
28 import jinja2
29 import pkg_resources
2630 import yaml
2731
2832
8387 return False
8488
8589
86 class Config(object):
90 class Config:
8791 """
8892 A configuration section, containing configuration keys and values.
8993
98102
99103 def __init__(self, root_config=None):
100104 self.root = root_config
105
106 # Get the path to the default Synapse template directory
107 self.default_template_dir = pkg_resources.resource_filename(
108 "synapse", "res/templates"
109 )
101110
102111 def __getattr__(self, item: str) -> Any:
103112 """
183192 with open(file_path) as file_stream:
184193 return file_stream.read()
185194
186
187 class RootConfig(object):
195 def read_templates(
196 self,
197 filenames: List[str],
198 custom_template_directory: Optional[str] = None,
199 autoescape: bool = False,
200 ) -> List[jinja2.Template]:
201 """Load a list of template files from disk using the given variables.
202
203 This function will attempt to load the given templates from the default Synapse
204 template directory. If `custom_template_directory` is supplied, that directory
205 is tried first.
206
207 Files read are treated as Jinja templates. These templates are not rendered yet.
208
209 Args:
210 filenames: A list of template filenames to read.
211
212 custom_template_directory: A directory to try to look for the templates
213 before using the default Synapse template directory instead.
214
215 autoescape: Whether to autoescape variables before inserting them into the
216 template.
217
218 Raises:
219 ConfigError: if the file's path is incorrect or otherwise cannot be read.
220
221 Returns:
222 A list of jinja2 templates.
223 """
224 templates = []
225 search_directories = [self.default_template_dir]
226
227 # The loader will first look in the custom template directory (if specified) for the
228 # given filename. If it doesn't find it, it will use the default template dir instead
229 if custom_template_directory:
230 # Check that the given template directory exists
231 if not self.path_exists(custom_template_directory):
232 raise ConfigError(
233 "Configured template directory does not exist: %s"
234 % (custom_template_directory,)
235 )
236
237 # Search the custom template directory as well
238 search_directories.insert(0, custom_template_directory)
239
240 loader = jinja2.FileSystemLoader(search_directories)
241 env = jinja2.Environment(loader=loader, autoescape=autoescape)
242
243 # Update the environment with our custom filters
244 env.filters.update(
245 {
246 "format_ts": _format_ts_filter,
247 "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
248 }
249 )
250
251 for filename in filenames:
252 # Load the template
253 template = env.get_template(filename)
254 templates.append(template)
255
256 return templates
257
258
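For orientation: the new `read_templates` helper above replaces the scattered `pkg_resources`/`os.path.isfile` checks in the config classes further down this diff. A minimal sketch of how a `Config` subclass might call it; the filename and the custom directory are hypothetical, not taken from the diff:

    # Runs inside a Config subclass (self.read_templates as defined above).
    # "my_notice.html" and the directory path are illustrative values.
    (notice_template,) = self.read_templates(
        ["my_notice.html"],
        custom_template_directory="/etc/synapse/templates",
    )
    # Templates are returned unrendered; render them at request time:
    html = notice_template.render(display_name="Alice")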
259 def _format_ts_filter(value: int, format: str):
260 return time.strftime(format, time.localtime(value / 1000))
261
262
263 def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
264 """Create and return a jinja2 filter that converts MXC urls to HTTP
265
266 Args:
267 public_baseurl: The public, accessible base URL of the homeserver
268 """
269
270 def mxc_to_http_filter(value, width, height, resize_method="crop"):
271 if value[0:6] != "mxc://":
272 return ""
273
274 server_and_media_id = value[6:]
275 fragment = None
276 if "#" in server_and_media_id:
277 server_and_media_id, fragment = server_and_media_id.split("#", 1)
278 fragment = "#" + fragment
279
280 params = {"width": width, "height": height, "method": resize_method}
281 return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
282 public_baseurl,
283 server_and_media_id,
284 urllib.parse.urlencode(params),
285 fragment or "",
286 )
287
288 return mxc_to_http_filter
289
290
291 class RootConfig:
188292 """
189293 Holder of an application's configuration.
190294
3232 _DEFAULT_EVENT_CACHE_SIZE = "10K"
3333
3434
35 class CacheProperties(object):
35 class CacheProperties:
3636 def __init__(self):
3737 # The default factor size for all caches
3838 self.default_factor_size = float(
2222 from typing import Optional
2323
2424 import attr
25 import pkg_resources
2625
2726 from ._base import Config, ConfigError
2827
9796 if parsed[1] == "":
9897 raise RuntimeError("Invalid notif_from address")
9998
99 # A user-configurable template directory
100100 template_dir = email_config.get("template_dir")
101 # we need an absolute path, because we change directory after starting (and
102 # we don't yet know what auxiliary templates like mail.css we will need).
103 # (Note that loading as package_resources with jinja.PackageLoader doesn't
104 # work for the same reason.)
105 if not template_dir:
106 template_dir = pkg_resources.resource_filename("synapse", "res/templates")
107
108 self.email_template_dir = os.path.abspath(template_dir)
101 if isinstance(template_dir, str):
102 # We need an absolute path, because we change directory after starting (and
103 # we don't yet know what auxiliary templates like mail.css we will need).
104 template_dir = os.path.abspath(template_dir)
105 elif template_dir is not None:
106 # If template_dir is something other than a str or None, warn the user
107 raise ConfigError("Config option email.template_dir must be type str")
109108
110109 self.email_enable_notifs = email_config.get("enable_notifs", False)
111
112 account_validity_config = config.get("account_validity") or {}
113 account_validity_renewal_enabled = account_validity_config.get("renew_at")
114110
115111 self.threepid_behaviour_email = (
116112 # Have Synapse handle the email sending if account_threepid_delegates.email
165161 email_config.get("validation_token_lifetime", "1h")
166162 )
167163
168 if (
169 self.email_enable_notifs
170 or account_validity_renewal_enabled
171 or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
172 ):
173 # make sure we can import the required deps
174 import bleach
175 import jinja2
176
177 # prevent unused warnings
178 jinja2
179 bleach
180
181164 if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
182165 missing = []
183166 if not self.email_notif_from:
195178
196179 # These email templates have placeholders in them, and thus must be
197180 # parsed using a templating engine during a request
198 self.email_password_reset_template_html = email_config.get(
181 password_reset_template_html = email_config.get(
199182 "password_reset_template_html", "password_reset.html"
200183 )
201 self.email_password_reset_template_text = email_config.get(
184 password_reset_template_text = email_config.get(
202185 "password_reset_template_text", "password_reset.txt"
203186 )
204 self.email_registration_template_html = email_config.get(
187 registration_template_html = email_config.get(
205188 "registration_template_html", "registration.html"
206189 )
207 self.email_registration_template_text = email_config.get(
190 registration_template_text = email_config.get(
208191 "registration_template_text", "registration.txt"
209192 )
210 self.email_add_threepid_template_html = email_config.get(
193 add_threepid_template_html = email_config.get(
211194 "add_threepid_template_html", "add_threepid.html"
212195 )
213 self.email_add_threepid_template_text = email_config.get(
196 add_threepid_template_text = email_config.get(
214197 "add_threepid_template_text", "add_threepid.txt"
215198 )
216199
217 self.email_password_reset_template_failure_html = email_config.get(
200 password_reset_template_failure_html = email_config.get(
218201 "password_reset_template_failure_html", "password_reset_failure.html"
219202 )
220 self.email_registration_template_failure_html = email_config.get(
203 registration_template_failure_html = email_config.get(
221204 "registration_template_failure_html", "registration_failure.html"
222205 )
223 self.email_add_threepid_template_failure_html = email_config.get(
206 add_threepid_template_failure_html = email_config.get(
224207 "add_threepid_template_failure_html", "add_threepid_failure.html"
225208 )
226209
227210 # These templates do not support any placeholder variables, so we
228211 # will read them from disk once during setup
229 email_password_reset_template_success_html = email_config.get(
212 password_reset_template_success_html = email_config.get(
230213 "password_reset_template_success_html", "password_reset_success.html"
231214 )
232 email_registration_template_success_html = email_config.get(
215 registration_template_success_html = email_config.get(
233216 "registration_template_success_html", "registration_success.html"
234217 )
235 email_add_threepid_template_success_html = email_config.get(
218 add_threepid_template_success_html = email_config.get(
236219 "add_threepid_template_success_html", "add_threepid_success.html"
237220 )
238221
239 # Check templates exist
240 for f in [
222 # Read all templates from disk
223 (
241224 self.email_password_reset_template_html,
242225 self.email_password_reset_template_text,
243226 self.email_registration_template_html,
247230 self.email_password_reset_template_failure_html,
248231 self.email_registration_template_failure_html,
249232 self.email_add_threepid_template_failure_html,
250 email_password_reset_template_success_html,
251 email_registration_template_success_html,
252 email_add_threepid_template_success_html,
253 ]:
254 p = os.path.join(self.email_template_dir, f)
255 if not os.path.isfile(p):
256 raise ConfigError("Unable to find template file %s" % (p,))
257
258 # Retrieve content of web templates
259 filepath = os.path.join(
260 self.email_template_dir, email_password_reset_template_success_html
261 )
262 self.email_password_reset_template_success_html = self.read_file(
263 filepath, "email.password_reset_template_success_html"
264 )
265 filepath = os.path.join(
266 self.email_template_dir, email_registration_template_success_html
267 )
268 self.email_registration_template_success_html_content = self.read_file(
269 filepath, "email.registration_template_success_html"
270 )
271 filepath = os.path.join(
272 self.email_template_dir, email_add_threepid_template_success_html
273 )
274 self.email_add_threepid_template_success_html_content = self.read_file(
275 filepath, "email.add_threepid_template_success_html"
233 password_reset_template_success_html_template,
234 registration_template_success_html_template,
235 add_threepid_template_success_html_template,
236 ) = self.read_templates(
237 [
238 password_reset_template_html,
239 password_reset_template_text,
240 registration_template_html,
241 registration_template_text,
242 add_threepid_template_html,
243 add_threepid_template_text,
244 password_reset_template_failure_html,
245 registration_template_failure_html,
246 add_threepid_template_failure_html,
247 password_reset_template_success_html,
248 registration_template_success_html,
249 add_threepid_template_success_html,
250 ],
251 template_dir,
252 )
253
254 # Render templates that do not contain any placeholders
255 self.email_password_reset_template_success_html_content = (
256 password_reset_template_success_html_template.render()
257 )
258 self.email_registration_template_success_html_content = (
259 registration_template_success_html_template.render()
260 )
261 self.email_add_threepid_template_success_html_content = (
262 add_threepid_template_success_html_template.render()
276263 )
277264
278265 if self.email_enable_notifs:
289276 % (", ".join(missing),)
290277 )
291278
292 self.email_notif_template_html = email_config.get(
279 notif_template_html = email_config.get(
293280 "notif_template_html", "notif_mail.html"
294281 )
295 self.email_notif_template_text = email_config.get(
282 notif_template_text = email_config.get(
296283 "notif_template_text", "notif_mail.txt"
297284 )
298285
299 for f in self.email_notif_template_text, self.email_notif_template_html:
300 p = os.path.join(self.email_template_dir, f)
301 if not os.path.isfile(p):
302 raise ConfigError("Unable to find email template file %s" % (p,))
286 (
287 self.email_notif_template_html,
288 self.email_notif_template_text,
289 ) = self.read_templates(
290 [notif_template_html, notif_template_text], template_dir,
291 )
303292
304293 self.email_notif_for_new_users = email_config.get(
305294 "notif_for_new_users", True
308297 "client_base_url", email_config.get("riot_base_url", None)
309298 )
310299
311 if account_validity_renewal_enabled:
312 self.email_expiry_template_html = email_config.get(
300 if self.account_validity.renew_by_email_enabled:
301 expiry_template_html = email_config.get(
313302 "expiry_template_html", "notice_expiry.html"
314303 )
315 self.email_expiry_template_text = email_config.get(
304 expiry_template_text = email_config.get(
316305 "expiry_template_text", "notice_expiry.txt"
317306 )
318307
319 for f in self.email_expiry_template_text, self.email_expiry_template_html:
320 p = os.path.join(self.email_template_dir, f)
321 if not os.path.isfile(p):
322 raise ConfigError("Unable to find email template file %s" % (p,))
308 (
309 self.account_validity_template_html,
310 self.account_validity_template_text,
311 ) = self.read_templates(
312 [expiry_template_html, expiry_template_text], template_dir,
313 )
323314
324315 subjects_config = email_config.get("subjects", {})
325316 subjects = {}
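The pattern in the hunk above: templates containing placeholders stay as `jinja2.Template` objects, while placeholder-free ones are rendered once at startup, replacing the old `read_file` calls. A toy, self-contained illustration of why the two are equivalent (plain jinja2, not Synapse code):

    import jinja2

    static = jinja2.Template("<p>Your account has been deactivated.</p>")
    # No placeholders: every render yields the same string, so rendering
    # once at startup is equivalent to reading the raw file contents.
    assert static.render() == static.render()

    dynamic = jinja2.Template("<p>Hello {{ display_name }}!</p>")
    # Placeholders: must stay a Template and be rendered per request.
    dynamic.render(display_name="Alice")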
399390 # Directory in which Synapse will try to find the template files below.
400391 # If not set, default templates from within the Synapse package will be used.
401392 #
402 # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
403 # If you *do* uncomment it, you will need to make sure that all the templates
404 # below are in the directory.
393 # Do not uncomment this setting unless you want to customise the templates.
405394 #
406395 # Synapse will look for the following templates in this directory:
407396 #
8181
8282
8383 @attr.s
84 class TrustedKeyServer(object):
84 class TrustedKeyServer:
8585 # string: name of the server.
8686 server_name = attr.ib()
8787
2121
2222
2323 @attr.s
24 class MetricsFlags(object):
24 class MetricsFlags:
2525 known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
2626
2727 @classmethod
1616 from ._base import Config
1717
1818
19 class RateLimitConfig(object):
19 class RateLimitConfig:
2020 def __init__(
2121 self,
2222 config: Dict[str, float],
2626 self.burst_count = config.get("burst_count", defaults["burst_count"])
2727
2828
29 class FederationRateLimitConfig(object):
29 class FederationRateLimitConfig:
3030 _items_and_default = {
3131 "window_size": 1000,
3232 "sleep_limit": 10,
2121 logger = logging.Logger(__name__)
2222
2323
24 class RoomDefaultEncryptionTypes(object):
24 class RoomDefaultEncryptionTypes:
2525 """Possible values for the encryption_enabled_by_default_for_room_type config option"""
2626
2727 ALL = "all"
148148 return False
149149
150150
151 class _RoomDirectoryRule(object):
151 class _RoomDirectoryRule:
152152 """Helper class to test whether a room directory action is allowed, like
153153 creating an alias or publishing a room.
154154 """
1717 from typing import Any, List
1818
1919 import attr
20 import jinja2
21 import pkg_resources
2220
2321 from synapse.python_dependencies import DependencyException, check_requirements
2422 from synapse.util.module_loader import load_module, load_python_module
170168 saml2_config.get("saml_session_lifetime", "15m")
171169 )
172170
173 template_dir = saml2_config.get("template_dir")
174 if not template_dir:
175 template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
176
177 loader = jinja2.FileSystemLoader(template_dir)
178 # enable auto-escape here, to avoid having to remember to escape manually in the
179 # template
180 env = jinja2.Environment(loader=loader, autoescape=True)
181 self.saml2_error_html_template = env.get_template("saml_error.html")
171 # We enable autoescape here as the message may potentially come from a
172 # remote resource
173 self.saml2_error_html_template = self.read_templates(
174 ["saml_error.html"], saml2_config.get("template_dir"), autoescape=True
175 )[0]
182176
183177 def _default_saml_config_dict(
184178 self, required_attributes: set, optional_attributes: set
2525
2626 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
2727 from synapse.http.endpoint import parse_and_validate_server_name
28 from synapse.python_dependencies import DependencyException, check_requirements
2928
3029 from ._base import Config, ConfigError
3130
424423 self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
425424
426425 @attr.s
427 class LimitRemoteRoomsConfig(object):
426 class LimitRemoteRoomsConfig:
428427 enabled = attr.ib(
429428 validator=attr.validators.instance_of(bool), default=False
430429 )
506505 ),
507506 )
508507 )
509
510 _check_resource_config(self.listeners)
511508
512509 self.cleanup_extremities_with_dummy_events = config.get(
513510 "cleanup_extremities_with_dummy_events", True
963960 # min_lifetime: 1d
964961 # max_lifetime: 1y
965962
966 # Retention policy limits. If set, a user won't be able to send a
967 # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
968 # that's not within this range. This is especially useful in closed federations,
969 # in which server admins can make sure every federating server applies the same
970 # rules.
963 # Retention policy limits. If set, and a room's state contains an
964 # 'm.room.retention' event with a 'min_lifetime' or a
965 # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
966 # to these limits when running purge jobs.
971967 #
972968 #allowed_lifetime_min: 1d
973969 #allowed_lifetime_max: 1y
993989 # (e.g. every 12h), but not want that purge to be performed by a job that's
994990 # iterating over every room it knows, which could be heavy on the server.
995991 #
992 # If any purge job is configured, it is strongly recommended to have at least
993 # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
994 # set, or one job without 'shortest_max_lifetime' and one job without
995 # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
996 # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
997 # room's policy to these values is done after the policies are retrieved from
998 # Synapse's database (which is done using the range specified in a purge job's
999 # configuration).
1000 #
9961001 #purge_jobs:
997 # - shortest_max_lifetime: 1d
998 # longest_max_lifetime: 3d
1002 # - longest_max_lifetime: 3d
9991003 # interval: 12h
10001004 # - shortest_max_lifetime: 3d
1001 # longest_max_lifetime: 1y
10021005 # interval: 1d
10031006
10041007 # Inhibits the /requestToken endpoints from returning an error that might leak
11321135 if name == "webclient":
11331136 logger.warning(NO_MORE_WEB_CLIENT_WARNING)
11341137 return
1135
1136
1137 def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
1138 resource_names = {
1139 res_name
1140 for listener in listeners
1141 if listener.http_options
1142 for res in listener.http_options.resources
1143 for res_name in res.names
1144 }
1145
1146 for resource in resource_names:
1147 if resource == "consent":
1148 try:
1149 check_requirements("resources.consent")
1150 except DependencyException as e:
1151 raise ConfigError(e.message)
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import os
1514 from typing import Any, Dict
16
17 import pkg_resources
1815
1916 from ._base import Config
2017
2825 def read_config(self, config, **kwargs):
2926 sso_config = config.get("sso") or {} # type: Dict[str, Any]
3027
31 # Pick a template directory in order of:
32 # * The sso-specific template_dir
33 # * /path/to/synapse/install/res/templates
28 # The sso-specific template_dir
3429 template_dir = sso_config.get("template_dir")
35 if not template_dir:
36 template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
3730
38 self.sso_template_dir = template_dir
39 self.sso_account_deactivated_template = self.read_file(
40 os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
41 "sso_account_deactivated_template",
31 # Read templates from disk
32 (
33 self.sso_redirect_confirm_template,
34 self.sso_auth_confirm_template,
35 self.sso_error_template,
36 sso_account_deactivated_template,
37 sso_auth_success_template,
38 ) = self.read_templates(
39 [
40 "sso_redirect_confirm.html",
41 "sso_auth_confirm.html",
42 "sso_error.html",
43 "sso_account_deactivated.html",
44 "sso_auth_success.html",
45 ],
46 template_dir,
4247 )
43 self.sso_auth_success_template = self.read_file(
44 os.path.join(self.sso_template_dir, "sso_auth_success.html"),
45 "sso_auth_success_template",
48
49 # These templates have no placeholders, so render them here
50 self.sso_account_deactivated_template = (
51 sso_account_deactivated_template.render()
4652 )
53 self.sso_auth_success_template = sso_auth_success_template.render()
4754
4855 self.sso_client_whitelist = sso_config.get("client_whitelist") or []
4956
8282
8383
8484 @implementer(IPolicyForHTTPS)
85 class FederationPolicyForHTTPS(object):
85 class FederationPolicyForHTTPS:
8686 """Factory for Twisted SSLClientConnectionCreators that are used to make connections
8787 to remote servers for federation.
8888
151151
152152
153153 @implementer(IPolicyForHTTPS)
154 class RegularPolicyForHTTPS(object):
154 class RegularPolicyForHTTPS:
155155 """Factory for Twisted SSLClientConnectionCreators that are used to make connections
156156 to remote servers, for other than federation.
157157
188188
189189
190190 @implementer(IOpenSSLClientConnectionCreator)
191 class SSLClientConnectionCreator(object):
191 class SSLClientConnectionCreator:
192192 """Creates openssl connection objects for client connections.
193193
194194 Replaces twisted.internet.ssl.ClientTLSOptions
213213 return connection
214214
215215
216 class ConnectionVerifier(object):
216 class ConnectionVerifier:
217217 """Set the SNI, and do cert verification
218218
219219 This is a thing which is attached to the TLSMemoryBIOProtocol, and is called by
5656
5757
5858 @attr.s(slots=True, cmp=False)
59 class VerifyJsonRequest(object):
59 class VerifyJsonRequest:
6060 """
6161 A request to verify a JSON object.
6262
9595 pass
9696
9797
98 class Keyring(object):
98 class Keyring:
9999 def __init__(self, hs, key_fetchers=None):
100100 self.clock = hs.get_clock()
101101
419419 remaining_requests.difference_update(completed)
420420
421421
422 class KeyFetcher(object):
422 class KeyFetcher:
423423 async def get_keys(self, keys_to_fetch):
424424 """
425425 Args:
455455 return keys
456456
457457
458 class BaseV2KeyFetcher(object):
458 class BaseV2KeyFetcher:
459459 def __init__(self, hs):
460460 self.store = hs.get_datastore()
461461 self.config = hs.get_config()
756756 except Exception:
757757 logger.exception("Error getting keys %s from %s", key_ids, server_name)
758758
759 return await yieldable_gather_results(
760 get_key, keys_to_fetch.items()
761 ).addCallback(lambda _: results)
759 await yieldable_gather_results(get_key, keys_to_fetch.items())
760 return results
762761
763762 async def get_server_verify_key_v2_direct(self, server_name, key_ids):
764763 """
768767 key_ids (iterable[str]):
769768
770769 Returns:
771 Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
770 dict[str, FetchKeyResult]: map from key ID to lookup result
772771
773772 Raises:
774773 KeyLookupError if there was a problem making the lookup
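The `Keyring` hunk above replaces a `yieldable_gather_results(...).addCallback(...)` chain with a plain await followed by a return. A self-contained toy of the same pattern, using `asyncio.gather` in place of Synapse's `yieldable_gather_results` (all names here are illustrative):

    import asyncio

    async def fetch_all(key_ids):
        results = {}

        async def get_key(key_id):
            # stand-in for the per-key network fetch
            results[key_id] = key_id.upper()

        # Await the gathered work directly, then return the dict the
        # workers populated -- no callback attached to a Deferred.
        await asyncio.gather(*(get_key(k) for k in key_ids))
        return results

    print(asyncio.run(fetch_all(["ed25519:a", "ed25519:b"])))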
4646 Args:
4747 room_version_obj: the version of the room
4848 event: the event being checked.
49 auth_events (dict: event-key -> event): the existing room state.
49 auth_events: the existing room state.
5050
5151 Raises:
5252 AuthError if the checks fail
1717 import abc
1818 import os
1919 from distutils.util import strtobool
20 from typing import Dict, Optional, Type
20 from typing import Dict, Optional, Tuple, Type
2121
2222 from unpaddedbase64 import encode_base64
2323
9595 return instance._dict.get(self.key, self.default)
9696
9797
98 class _EventInternalMetadata(object):
98 class _EventInternalMetadata:
9999 __slots__ = ["_dict"]
100100
101101 def __init__(self, internal_metadata_dict: JsonDict):
119119 # be here
120120 before = DictProperty("before") # type: str
121121 after = DictProperty("after") # type: str
122 order = DictProperty("order") # type: int
122 order = DictProperty("order") # type: Tuple[int, int]
123123
124124 def get_dict(self) -> JsonDict:
125125 return dict(self._dict)
132132 rejection. This is needed as those events are marked as outliers, but
133133 they still need to be processed as if they're new events (e.g. updating
134134 invite state in the database, relaying to clients, etc).
135
136 (Added in synapse 0.99.0, so may be unreliable for events received before that)
135137 """
136138 return self._dict.get("out_of_band_membership", False)
137139
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Optional
14 from typing import Any, Dict, List, Optional, Tuple, Union
1515
1616 import attr
1717 from nacl.signing import SigningKey
3535
3636
3737 @attr.s(slots=True, cmp=False, frozen=True)
38 class EventBuilder(object):
38 class EventBuilder:
3939 """A format independent event builder used to build up the event content
4040 before signing the event.
4141
9696 def is_state(self):
9797 return self._state_key is not None
9898
99 async def build(self, prev_event_ids):
99 async def build(self, prev_event_ids: List[str]) -> EventBase:
100100 """Transform into a fully signed and hashed event
101101
102102 Args:
103 prev_event_ids (list[str]): The event IDs to use as the prev events
103 prev_event_ids: The event IDs to use as the prev events
104104
105105 Returns:
106 FrozenEvent
106 The signed and hashed event.
107107 """
108108
109109 state_ids = await self._state.get_current_state_ids(
113113
114114 format_version = self.room_version.event_format
115115 if format_version == EventFormatVersions.V1:
116 auth_events = await self._store.add_event_hashes(auth_ids)
117 prev_events = await self._store.add_event_hashes(prev_event_ids)
116 # The types of auth/prev events change between event versions.
117 auth_events = await self._store.add_event_hashes(
118 auth_ids
119 ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
120 prev_events = await self._store.add_event_hashes(
121 prev_event_ids
122 ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
118123 else:
119124 auth_events = auth_ids
120125 prev_events = prev_event_ids
137142 "unsigned": self.unsigned,
138143 "depth": depth,
139144 "prev_state": [],
140 }
145 } # type: Dict[str, Any]
141146
142147 if self.is_state():
143148 event_dict["state_key"] = self._state_key
158163 )
159164
160165
161 class EventBuilderFactory(object):
166 class EventBuilderFactory:
162167 def __init__(self, hs):
163168 self.clock = hs.get_clock()
164169 self.hostname = hs.hostname
1414 # limitations under the License.
1515
1616 import inspect
17 from typing import Any, Dict, List
17 from typing import Any, Dict, List, Optional, Tuple
1818
19 from synapse.spam_checker_api import SpamCheckerApi
19 from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
20 from synapse.types import Collection
2021
2122 MYPY = False
2223 if MYPY:
2324 import synapse.server
2425
2526
26 class SpamChecker(object):
27 class SpamChecker:
2728 def __init__(self, hs: "synapse.server.HomeServer"):
2829 self.spam_checkers = [] # type: List[Any]
2930
159160 return True
160161
161162 return False
163
164 def check_registration_for_spam(
165 self,
166 email_threepid: Optional[dict],
167 username: Optional[str],
168 request_info: Collection[Tuple[str, str]],
169 ) -> RegistrationBehaviour:
170 """Checks if we should allow the given registration request.
171
172 Args:
173 email_threepid: The email threepid used for registering, if any
174 username: The request user name, if any
175 request_info: List of tuples of user agent and IP that
176 were used during the registration process.
177
178 Returns:
179 Enum for how the request should be handled
180 """
181
182 for spam_checker in self.spam_checkers:
183 # For backwards compatibility, only run if the method exists on the
184 # spam checker
185 checker = getattr(spam_checker, "check_registration_for_spam", None)
186 if checker:
187 behaviour = checker(email_threepid, username, request_info)
188 assert isinstance(behaviour, RegistrationBehaviour)
189 if behaviour != RegistrationBehaviour.ALLOW:
190 return behaviour
191
192 return RegistrationBehaviour.ALLOW
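Because the registry only calls the hook when `getattr` finds it, existing spam-checker modules keep working unchanged, and new ones opt in simply by defining the method. A minimal illustrative checker (the class name and blocked prefix are made up; 192.0.2.0/24 is a documentation-only range):

    from synapse.spam_checker_api import RegistrationBehaviour

    class ExampleSpamChecker:
        def check_registration_for_spam(self, email_threepid, username, request_info):
            # request_info is a collection of (user_agent, ip) tuples
            for _user_agent, ip in request_info:
                if ip.startswith("192.0.2."):
                    return RegistrationBehaviour.DENY
            return RegistrationBehaviour.ALLOW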
1717 from synapse.types import Requester
1818
1919
20 class ThirdPartyEventRules(object):
20 class ThirdPartyEventRules:
2121 """Allows server admins to provide a Python module implementing an extra
2222 set of rules to apply when processing events.
2323
321321 return d
322322
323323
324 class EventClientSerializer(object):
324 class EventClientSerializer:
325325 """Serializes events that are to be sent to clients.
326326
327327 This is used for bundling extra information with any events to be sent to
1919 from synapse.types import EventID, RoomID, UserID
2020
2121
22 class EventValidator(object):
22 class EventValidator:
2323 def validate_new(self, event, config):
2424 """Validates the event has roughly the right format
2525
7373 )
7474
7575 if event.type == EventTypes.Retention:
76 self._validate_retention(event, config)
76 self._validate_retention(event)
7777
78 def _validate_retention(self, event, config):
78 def _validate_retention(self, event):
7979 """Checks that an event that defines the retention policy for a room respects the
80 boundaries imposed by the server's administrator.
80 format enforced by the spec.
8181
8282 Args:
8383 event (FrozenEvent): The event to validate.
84 config (Config): The homeserver's configuration.
8584 """
8685 min_lifetime = event.content.get("min_lifetime")
8786 max_lifetime = event.content.get("max_lifetime")
9493 errcode=Codes.BAD_JSON,
9594 )
9695
97 if (
98 config.retention_allowed_lifetime_min is not None
99 and min_lifetime < config.retention_allowed_lifetime_min
100 ):
101 raise SynapseError(
102 code=400,
103 msg=(
104 "'min_lifetime' can't be lower than the minimum allowed"
105 " value enforced by the server's administrator"
106 ),
107 errcode=Codes.BAD_JSON,
108 )
109
110 if (
111 config.retention_allowed_lifetime_max is not None
112 and min_lifetime > config.retention_allowed_lifetime_max
113 ):
114 raise SynapseError(
115 code=400,
116 msg=(
117 "'min_lifetime' can't be greater than the maximum allowed"
118 " value enforced by the server's administrator"
119 ),
120 errcode=Codes.BAD_JSON,
121 )
122
12396 if max_lifetime is not None:
12497 if not isinstance(max_lifetime, int):
12598 raise SynapseError(
12699 code=400,
127100 msg="'max_lifetime' must be an integer",
128 errcode=Codes.BAD_JSON,
129 )
130
131 if (
132 config.retention_allowed_lifetime_min is not None
133 and max_lifetime < config.retention_allowed_lifetime_min
134 ):
135 raise SynapseError(
136 code=400,
137 msg=(
138 "'max_lifetime' can't be lower than the minimum allowed value"
139 " enforced by the server's administrator"
140 ),
141 errcode=Codes.BAD_JSON,
142 )
143
144 if (
145 config.retention_allowed_lifetime_max is not None
146 and max_lifetime > config.retention_allowed_lifetime_max
147 ):
148 raise SynapseError(
149 code=400,
150 msg=(
151 "'max_lifetime' can't be greater than the maximum allowed"
152 " value enforced by the server's administrator"
153 ),
154101 errcode=Codes.BAD_JSON,
155102 )
156103
3838 logger = logging.getLogger(__name__)
3939
4040
41 class FederationBase(object):
41 class FederationBase:
4242 def __init__(self, hs):
4343 self.hs = hs
4444
5353 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
5454 from synapse.logging.context import make_deferred_yieldable, preserve_fn
5555 from synapse.logging.utils import log_function
56 from synapse.types import JsonDict
56 from synapse.types import JsonDict, get_domain_from_id
5757 from synapse.util import unwrapFirstError
5858 from synapse.util.caches.expiringcache import ExpiringCache
5959 from synapse.util.retryutils import NotRetryingDestination
216216 for p in transaction_data["pdus"]
217217 ]
218218
219 # FIXME: We should handle signature failures more gracefully.
220 pdus[:] = await make_deferred_yieldable(
221 defer.gatherResults(
222 self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
223 ).addErrback(unwrapFirstError)
219 # Check the signatures and hashes of the PDUs, removing from the list any that fail.
220 pdus[:] = await self._check_sigs_and_hash_and_fetch(
221 dest, pdus, outlier=True, room_version=room_version
224222 )
225223
226224 return pdus
385383 pdu.event_id, allow_rejected=True, allow_none=True
386384 )
387385
388 if not res and pdu.origin != origin:
386 pdu_origin = get_domain_from_id(pdu.sender)
387 if not res and pdu_origin != origin:
389388 try:
390389 res = await self.get_pdu(
391 destinations=[pdu.origin],
390 destinations=[pdu_origin],
392391 event_id=pdu.event_id,
393392 room_version=room_version,
394393 outlier=outlier,
2727 Union,
2828 )
2929
30 from canonicaljson import json
3130 from prometheus_client import Counter, Histogram
3231
3332 from twisted.internet import defer
6261 ReplicationGetQueryRestServlet,
6362 )
6463 from synapse.types import JsonDict, get_domain_from_id
65 from synapse.util import glob_to_regex, unwrapFirstError
64 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
6665 from synapse.util.async_helpers import Linearizer, concurrently_execute
6766 from synapse.util.caches.response_cache import ResponseCache
6867
9796 self.state = hs.get_state_handler()
9897
9998 self.device_handler = hs.get_device_handler()
99 self._federation_ratelimiter = hs.get_federation_ratelimiter()
100100
101101 self._server_linearizer = Linearizer("fed_server")
102102 self._transaction_linearizer = Linearizer("fed_txn_handler")
103
104 # We cache results for transaction with the same ID
105 self._transaction_resp_cache = ResponseCache(
106 hs, "fed_txn_handler", timeout_ms=30000
107 )
103108
104109 self.transaction_actions = TransactionActions(self.store)
105110
135140 request_time = self._clock.time_msec()
136141
137142 transaction = Transaction(**transaction_data)
138
139 if not transaction.transaction_id: # type: ignore
143 transaction_id = transaction.transaction_id # type: ignore
144
145 if not transaction_id:
140146 raise Exception("Transaction missing transaction_id")
141147
142 logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
143
144 # use a linearizer to ensure that we don't process the same transaction
145 # multiple times in parallel.
146 with (
147 await self._transaction_linearizer.queue(
148 (origin, transaction.transaction_id) # type: ignore
149 )
150 ):
151 result = await self._handle_incoming_transaction(
152 origin, transaction, request_time
153 )
148 logger.debug("[%s] Got transaction", transaction_id)
149
150 # We wrap in a ResponseCache so that we de-duplicate retried
151 # transactions.
152 return await self._transaction_resp_cache.wrap(
153 (origin, transaction_id),
154 self._on_incoming_transaction_inner,
155 origin,
156 transaction,
157 request_time,
158 )
159
160 async def _on_incoming_transaction_inner(
161 self, origin: str, transaction: Transaction, request_time: int
162 ) -> Tuple[int, Dict[str, Any]]:
163 # Use a linearizer to ensure that transactions from a remote are
164 # processed in order.
165 with await self._transaction_linearizer.queue(origin):
166 # We rate limit here *after* we've queued up the incoming requests,
167 # so that we don't fill up the ratelimiter with blocked requests.
168 #
169 # This is important as the ratelimiter allows N concurrent requests
170 # at a time, and only starts ratelimiting if there are more requests
171 # than that being processed at a time. If we queued up requests in
172 # the linearizer/response cache *after* the ratelimiting then those
173 # queued up requests would count as part of the allowed limit of N
174 # concurrent requests.
175 with self._federation_ratelimiter.ratelimit(origin) as d:
176 await d
177
178 result = await self._handle_incoming_transaction(
179 origin, transaction, request_time
180 )
154181
155182 return result
156183
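Keying the cache on `(origin, transaction_id)` means a retried `/send` of the same transaction awaits the original, possibly still in-flight, result instead of being processed twice. A toy, self-contained sketch of that de-duplication pattern; Synapse's real `ResponseCache` additionally handles expiry and logcontexts:

    import asyncio

    class MiniResponseCache:
        def __init__(self):
            self._pending = {}

        async def wrap(self, key, callback, *args):
            fut = self._pending.get(key)
            if fut is None:
                # First request for this key: kick off the real work.
                fut = asyncio.ensure_future(callback(*args))
                self._pending[key] = fut
            # Retries with the same key share the same future and result.
            return await fut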
550577 for device_id, keys in device_keys.items():
551578 for key_id, json_str in keys.items():
552579 json_result.setdefault(user_id, {})[device_id] = {
553 key_id: json.loads(json_str)
580 key_id: json_decoder.decode(json_str)
554581 }
555582
556583 logger.info(
785812 return regex.match(server_name)
786813
787814
788 class FederationHandlerRegistry(object):
815 class FederationHandlerRegistry:
789816 """Allows classes to register themselves as handlers for a given EDU or
790817 query type for incoming federation traffic.
791818 """
1919 """
2020
2121 import logging
22 from typing import Optional, Tuple
2223
24 from synapse.federation.units import Transaction
2325 from synapse.logging.utils import log_function
26 from synapse.types import JsonDict
2427
2528 logger = logging.getLogger(__name__)
2629
2730
28 class TransactionActions(object):
31 class TransactionActions:
2932 """ Defines persistence actions that relate to handling Transactions.
3033 """
3134
3336 self.store = datastore
3437
3538 @log_function
36 def have_responded(self, origin, transaction):
37 """ Have we already responded to a transaction with the same id and
39 async def have_responded(
40 self, origin: str, transaction: Transaction
41 ) -> Optional[Tuple[int, JsonDict]]:
42 """Have we already responded to a transaction with the same id and
3843 origin?
3944
4045 Returns:
41 Deferred: Results in `None` if we have not previously responded to
42 this transaction or a 2-tuple of `(int, dict)` representing the
43 response code and response body.
46 `None` if we have not previously responded to this transaction or a
47 2-tuple of `(int, dict)` representing the response code and response body.
4448 """
45 if not transaction.transaction_id:
49 transaction_id = transaction.transaction_id # type: ignore
50 if not transaction_id:
4651 raise RuntimeError("Cannot persist a transaction with no transaction_id")
4752
48 return self.store.get_received_txn_response(transaction.transaction_id, origin)
53 return await self.store.get_received_txn_response(transaction_id, origin)
4954
5055 @log_function
51 def set_response(self, origin, transaction, code, response):
52 """ Persist how we responded to a transaction.
53
54 Returns:
55 Deferred
56 async def set_response(
57 self, origin: str, transaction: Transaction, code: int, response: JsonDict
58 ) -> None:
59 """Persist how we responded to a transaction.
5660 """
57 if not transaction.transaction_id:
61 transaction_id = transaction.transaction_id # type: ignore
62 if not transaction_id:
5863 raise RuntimeError("Cannot persist a transaction with no transaction_id")
5964
60 return self.store.set_received_txn_response(
61 transaction.transaction_id, origin, code, response
65 await self.store.set_received_txn_response(
66 transaction_id, origin, code, response
6267 )
3636
3737 from twisted.internet import defer
3838
39 from synapse.api.presence import UserPresenceState
3940 from synapse.metrics import LaterGauge
40 from synapse.storage.presence import UserPresenceState
4141 from synapse.util.metrics import Measure
4242
4343 from .units import Edu
4545 logger = logging.getLogger(__name__)
4646
4747
48 class FederationRemoteSendQueue(object):
48 class FederationRemoteSendQueue:
4949 """A drop in replacement for FederationSender"""
5050
5151 def __init__(self, hs):
364364 )
365365
366366
367 class BaseFederationRow(object):
367 class BaseFederationRow:
368368 """Base class for rows to be sent in the federation stream.
369369
370370 Specifies how to identify, serialize and deserialize the different types.
2121
2222 import synapse
2323 import synapse.metrics
24 from synapse.api.presence import UserPresenceState
2425 from synapse.events import EventBase
2526 from synapse.federation.sender.per_destination_queue import PerDestinationQueue
2627 from synapse.federation.sender.transaction_manager import TransactionManager
3839 events_processed_counter,
3940 )
4041 from synapse.metrics.background_process_metrics import run_as_background_process
41 from synapse.storage.presence import UserPresenceState
4242 from synapse.types import ReadReceipt
4343 from synapse.util.metrics import Measure, measure_func
4444
5555 )
5656
5757
58 class FederationSender(object):
58 class FederationSender:
5959 def __init__(self, hs: "synapse.server.HomeServer"):
6060 self.hs = hs
6161 self.server_name = hs.hostname
106106 d.pending_edu_count() for d in self._per_destination_queues.values()
107107 ),
108108 )
109
110 self._order = 1
111109
112110 self._is_processing = False
113111 self._last_poked_id = -1
271269 # a transaction in progress. If we do, stick it in the pending_pdus
272270 # table and we'll get back to it later.
273271
274 order = self._order
275 self._order += 1
276
277272 destinations = set(destinations)
278273 destinations.discard(self.server_name)
279274 logger.debug("Sending to: %s", str(destinations))
285280 sent_pdus_destination_dist_count.inc()
286281
287282 for destination in destinations:
288 self._get_per_destination_queue(destination).send_pdu(pdu, order)
283 self._get_per_destination_queue(destination).send_pdu(pdu)
289284
290285 async def send_read_receipt(self, receipt: ReadReceipt) -> None:
291286 """Send a RR to any other servers in the room
328323 room_id = receipt.room_id
329324
330325 # Work out which remote servers should be poked and poke them.
331 domains = await self.state.get_current_hosts_in_room(room_id)
326 domains_set = await self.state.get_current_hosts_in_room(room_id)
332327 domains = [
333328 d
334 for d in domains
329 for d in domains_set
335330 if d != self.server_name
336331 and self._federation_shard_config.should_handle(self._instance_name, d)
337332 ]
2323 HttpResponseException,
2424 RequestSendFailed,
2525 )
26 from synapse.api.presence import UserPresenceState
2627 from synapse.events import EventBase
2728 from synapse.federation.units import Edu
2829 from synapse.handlers.presence import format_user_presence_state
2930 from synapse.metrics import sent_transactions_counter
3031 from synapse.metrics.background_process_metrics import run_as_background_process
31 from synapse.storage.presence import UserPresenceState
3232 from synapse.types import ReadReceipt
3333 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
3434
5252 )
5353
5454
55 class PerDestinationQueue(object):
55 class PerDestinationQueue:
5656 """
5757 Manages the per-destination transmission queues.
5858
9191 self._destination = destination
9292 self.transmission_loop_running = False
9393
94 # a list of tuples of (pending pdu, order)
95 self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
94 # a list of pending PDUs
95 self._pending_pdus = [] # type: List[EventBase]
9696
9797 # XXX this is never actually used: see
9898 # https://github.com/matrix-org/synapse/issues/7549
131131 + len(self._pending_edus_keyed)
132132 )
133133
134 def send_pdu(self, pdu: EventBase, order: int) -> None:
134 def send_pdu(self, pdu: EventBase) -> None:
135135 """Add a PDU to the queue, and start the transmission loop if necessary
136136
137137 Args:
138138 pdu: pdu to send
139 order
140139 """
141 self._pending_pdus.append((pdu, order))
140 self._pending_pdus.append(pdu)
142141 self.attempt_new_transaction()
143142
144143 def send_presence(self, states: Iterable[UserPresenceState]) -> None:
184183 returns immediately. Otherwise kicks off the process of sending a
185184 transaction in the background.
186185 """
187 # list of (pending_pdu, deferred, order)
186
188187 if self.transmission_loop_running:
189188 # XXX: this can get stuck on by a never-ending
190189 # request at which point pending_pdus just keeps growing.
209208 )
210209
211210 async def _transaction_transmission_loop(self) -> None:
212 pending_pdus = [] # type: List[Tuple[EventBase, int]]
211 pending_pdus = [] # type: List[EventBase]
213212 try:
214213 self.transmission_loop_running = True
215214
336335 (e.retry_last_ts + e.retry_interval) / 1000.0
337336 ),
338337 )
338
339 if e.retry_interval > 60 * 60 * 1000:
340 # we won't retry for another hour!
341 # (this suggests a significant outage)
342 # We drop pending PDUs and EDUs because otherwise they will
343 # rack up indefinitely.
344 # Note that:
345 # - the EDUs that are being dropped here are those that we can
346 # afford to drop (specifically, only typing notifications,
347 # read receipts and presence updates are being dropped here)
348 # - Other EDUs such as to_device messages are queued with a
349 # different mechanism
350 # - this is all volatile state that would be lost if the
351 # federation sender restarted anyway
352
353 # dropping read receipts is a bit sad but should be solved
354 # through another mechanism, because this is all volatile!
355 self._pending_pdus = []
356 self._pending_edus = []
357 self._pending_edus_keyed = {}
358 self._pending_presence = {}
359 self._pending_rrs = {}
339360 except FederationDeniedError as e:
340361 logger.info(e)
341362 except HttpResponseException as e:
350371 "TX [%s] Failed to send transaction: %s", self._destination, e
351372 )
352373
353 for p, _ in pending_pdus:
374 for p in pending_pdus:
354375 logger.info(
355376 "Failed to send event %s to %s", p.event_id, self._destination
356377 )
357378 except Exception:
358379 logger.exception("TX [%s] Failed to send transaction", self._destination)
359 for p, _ in pending_pdus:
380 for p in pending_pdus:
360381 logger.info(
361382 "Failed to send event %s to %s", p.event_id, self._destination
362383 )
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from typing import TYPE_CHECKING, List, Tuple
16
17 from canonicaljson import json
15 from typing import TYPE_CHECKING, List
1816
1917 from synapse.api.errors import HttpResponseException
2018 from synapse.events import EventBase
2725 tags,
2826 whitelisted_homeserver,
2927 )
28 from synapse.util import json_decoder
3029 from synapse.util.metrics import measure_func
3130
3231 if TYPE_CHECKING:
3534 logger = logging.getLogger(__name__)
3635
3736
38 class TransactionManager(object):
37 class TransactionManager:
3938 """Helper class which handles building and sending transactions
4039
4140 shared between PerDestinationQueue objects
5352
5453 @measure_func("_send_new_transaction")
5554 async def send_new_transaction(
56 self,
57 destination: str,
58 pending_pdus: List[Tuple[EventBase, int]],
59 pending_edus: List[Edu],
60 ):
55 self, destination: str, pdus: List[EventBase], edus: List[Edu],
56 ) -> bool:
57 """
58 Args:
59 destination: The destination to send to (e.g. 'example.org')
60 pdus: In-order list of PDUs to send
61 edus: List of EDUs to send
62
63 Returns:
64 True iff the transaction was successful
65 """
6166
6267 # Make a transaction-sending opentracing span. This span follows on from
6368 # all the edus in that transaction. This needs to be done since there is
6772 span_contexts = []
6873 keep_destination = whitelisted_homeserver(destination)
6974
70 for edu in pending_edus:
75 for edu in edus:
7176 context = edu.get_context()
7277 if context:
73 span_contexts.append(extract_text_map(json.loads(context)))
78 span_contexts.append(extract_text_map(json_decoder.decode(context)))
7479 if keep_destination:
7580 edu.strip_context()
7681
7782 with start_active_span_follows_from("send_transaction", span_contexts):
78
79 # Sort based on the order field
80 pending_pdus.sort(key=lambda t: t[1])
81 pdus = [x[0] for x in pending_pdus]
82 edus = pending_edus
83
8483 success = True
8584
8685 logger.debug("TX [%s] _attempt_new_transaction", destination)
2929 logger = logging.getLogger(__name__)
3030
3131
32 class TransportLayerClient(object):
32 class TransportLayerClient:
3333 """Sends federation HTTP requests to other servers"""
3434
3535 def __init__(self, hs):
4444 )
4545 from synapse.server import HomeServer
4646 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
47 from synapse.util.ratelimitutils import FederationRateLimiter
4847 from synapse.util.versionstring import get_version_string
4948
5049 logger = logging.getLogger(__name__)
7170 super(TransportLayerServer, self).__init__(hs, canonical_json=False)
7271
7372 self.authenticator = Authenticator(hs)
74 self.ratelimiter = FederationRateLimiter(
75 self.clock, config=hs.config.rc_federation
76 )
73 self.ratelimiter = hs.get_federation_ratelimiter()
7774
7875 self.register_servlets()
7976
9996 pass
10097
10198
102 class Authenticator(object):
99 class Authenticator:
103100 def __init__(self, hs: HomeServer):
104101 self._clock = hs.get_clock()
105102 self.keyring = hs.get_keyring()
227224 )
228225
229226
230 class BaseFederationServlet(object):
227 class BaseFederationServlet:
231228 """Abstract base class for federation servlet classes.
232229
233230 The servlet object should have a PATH attribute which takes the form of a regexp to
270267 REQUIRE_AUTH = True
271268
272269 PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
270
271 RATELIMIT = True # Whether to rate limit requests or not
273272
274273 def __init__(self, handler, authenticator, ratelimiter, server_name):
275274 self.handler = handler
334333 )
335334
336335 with scope:
337 if origin:
336 if origin and self.RATELIMIT:
338337 with ratelimiter.ratelimit(origin) as d:
339338 await d
340339 if request._disconnected:
370369
371370 class FederationSendServlet(BaseFederationServlet):
372371 PATH = "/send/(?P<transaction_id>[^/]*)/?"
372
373 # We ratelimit manually in the handler as we queue up the requests and we
374 # don't want to fill up the ratelimiter with blocked requests.
375 RATELIMIT = False
373376
374377 def __init__(self, handler, server_name, **kwargs):
375378 super(FederationSendServlet, self).__init__(
106106 if "edus" in kwargs and not kwargs["edus"]:
107107 del kwargs["edus"]
108108
109 super(Transaction, self).__init__(
110 transaction_id=transaction_id, pdus=pdus, **kwargs
111 )
109 super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
112110
113111 @staticmethod
114112 def create_new(pdus, **kwargs):
5959 UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
6060
6161
62 class GroupAttestationSigning(object):
62 class GroupAttestationSigning:
6363 """Creates and verifies group attestations.
6464 """
6565
123123 )
124124
125125
126 class GroupAttestionRenewer(object):
126 class GroupAttestionRenewer:
127127 """Responsible for sending and receiving attestation updates.
128128 """
129129
3131 # TODO: Flairs
3232
3333
34 class GroupsServerWorkerHandler(object):
34 class GroupsServerWorkerHandler:
3535 def __init__(self, hs):
3636 self.hs = hs
3737 self.store = hs.get_datastore()
1919 from .search import SearchHandler
2020
2121
22 class Handlers(object):
22 class Handlers:
2323
2424 """ Deprecated. A collection of handlers.
2525
2424 logger = logging.getLogger(__name__)
2525
2626
27 class BaseHandler(object):
27 class BaseHandler:
2828 """
2929 Common base class for the event handlers.
3030 """
1313 # limitations under the License.
1414
1515
16 class AccountDataEventSource(object):
16 class AccountDataEventSource:
1717 def __init__(self, hs):
1818 self.store = hs.get_datastore()
1919
2525 from synapse.types import UserID
2626 from synapse.util import stringutils
2727
28 try:
29 from synapse.push.mailer import load_jinja2_templates
30 except ImportError:
31 load_jinja2_templates = None
32
3328 logger = logging.getLogger(__name__)
3429
3530
36 class AccountValidityHandler(object):
31 class AccountValidityHandler:
3732 def __init__(self, hs):
3833 self.hs = hs
3934 self.config = hs.config
4641 if (
4742 self._account_validity.enabled
4843 and self._account_validity.renew_by_email_enabled
49 and load_jinja2_templates
5044 ):
5145 # Don't do email-specific configuration if renewal by email is disabled.
46 self._template_html = self.config.account_validity_template_html
47 self._template_text = self.config.account_validity_template_text
48
5249 try:
5350 app_name = self.hs.config.email_app_name
5451
6360 self._from_string = self.hs.config.email_notif_from
6461
6562 self._raw_from = email.utils.parseaddr(self._from_string)[1]
66
67 self._template_html, self._template_text = load_jinja2_templates(
68 self.config.email_template_dir,
69 [
70 self.config.email_expiry_template_html,
71 self.config.email_expiry_template_text,
72 ],
73 apply_format_ts_filter=True,
74 apply_mxc_to_http_filter=True,
75 public_baseurl=self.config.public_baseurl,
76 )
7763
7864 # Check the renewal emails to send and send them every 30min.
7965 def send_emails():
3333 --------------------------------------------------------------------------------"""
3434
3535
36 class AcmeHandler(object):
36 class AcmeHandler:
3737 def __init__(self, hs):
3838 self.hs = hs
3939 self.reactor = hs.get_reactor()
7777
7878 @attr.s
7979 @implementer(ICertificateStore)
80 class ErsatzStore(object):
80 class ErsatzStore:
8181 """
8282 A store that only stores in memory.
8383 """
196196 return writer.finished()
197197
198198
199 class ExfiltrationWriter(object):
199 class ExfiltrationWriter:
200200 """Interface used to specify how to write exported data.
201201 """
202202
3333 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
3434
3535
36 class ApplicationServicesHandler(object):
36 class ApplicationServicesHandler:
3737 def __init__(self, hs):
3838 self.store = hs.get_datastore()
3939 self.is_mine_id = hs.is_mine_id
4141 from synapse.logging.context import defer_to_thread
4242 from synapse.metrics.background_process_metrics import run_as_background_process
4343 from synapse.module_api import ModuleApi
44 from synapse.push.mailer import load_jinja2_templates
45 from synapse.types import Requester, UserID
44 from synapse.types import JsonDict, Requester, UserID
4645 from synapse.util import stringutils as stringutils
46 from synapse.util.msisdn import phone_number_to_msisdn
4747 from synapse.util.threepids import canonicalise_email
4848
4949 from ._base import BaseHandler
5050
5151 logger = logging.getLogger(__name__)
52
53
54 def convert_client_dict_legacy_fields_to_identifier(
55 submission: JsonDict,
56 ) -> Dict[str, str]:
57 """
58 Convert a legacy-formatted login submission to an identifier dict.
59
60 Legacy login submissions (used in both login and user-interactive authentication)
61 provide user-identifying information at the top level rather than in an "identifier" dict.
62
63 These are now deprecated and replaced with identifiers:
64 https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
65
66 Args:
67 submission: The client dict to convert
68
69 Returns:
70 The matching identifier dict
71
72 Raises:
73 SynapseError: If the format of the client dict is invalid
74 """
75 identifier = submission.get("identifier", {})
76
77 # Generate an m.id.user identifier if "user" parameter is present
78 user = submission.get("user")
79 if user:
80 identifier = {"type": "m.id.user", "user": user}
81
82 # Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present
83 medium = submission.get("medium")
84 address = submission.get("address")
85 if medium and address:
86 identifier = {
87 "type": "m.id.thirdparty",
88 "medium": medium,
89 "address": address,
90 }
91
92 # We've converted valid, legacy login submissions to an identifier. If the
93 # submission still doesn't have an identifier, it's invalid
94 if not identifier:
95 raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM)
96
97 # Ensure the identifier has a type
98 if "type" not in identifier:
99 raise SynapseError(
100 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
101 )
102
103 return identifier
104
105
106 def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
107 """
108 Convert a phone login identifier type to a generic threepid identifier.
109
110 Args:
111 identifier: Login identifier dict of type 'm.id.phone'
112
113 Returns:
114 An equivalent m.id.thirdparty identifier dict
115 """
116 if "country" not in identifier or (
117 # The specification requires a "phone" field, while Synapse used to require a "number"
118 # field. Accept both for backwards compatibility.
119 "phone" not in identifier
120 and "number" not in identifier
121 ):
122 raise SynapseError(
123 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
124 )
125
126 # Accept both "phone" and "number" as valid keys in m.id.phone
127 phone_number = identifier.get("phone", identifier["number"])
128
129 # Convert user-provided phone number to a consistent representation
130 msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
131
132 return {
133 "type": "m.id.thirdparty",
134 "medium": "msisdn",
135 "address": msisdn,
136 }
52137
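Expected behaviour of the two helpers above, as illustrative calls (all values are made up; the msisdn result assumes the usual E.164 normalisation performed by `phone_number_to_msisdn`):

    convert_client_dict_legacy_fields_to_identifier({"user": "alice"})
    # -> {"type": "m.id.user", "user": "alice"}

    convert_client_dict_legacy_fields_to_identifier(
        {"medium": "email", "address": "alice@example.com"}
    )
    # -> {"type": "m.id.thirdparty", "medium": "email", "address": "alice@example.com"}

    login_id_phone_to_thirdparty(
        {"type": "m.id.phone", "country": "GB", "phone": "07700 900123"}
    )
    # -> {"type": "m.id.thirdparty", "medium": "msisdn", "address": "447700900123"}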
53138
54139 class AuthHandler(BaseHandler):
131216 # after the SSO completes and before redirecting them back to their client.
132217 # It notifies the user they are about to give access to their matrix account
133218 # to the client.
134 self._sso_redirect_confirm_template = load_jinja2_templates(
135 hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
136 )[0]
219 self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
220
137221 # The following template is shown during user interactive authentication
138222 # in the fallback auth scenario. It notifies the user that they are
139223 # authenticating for an operation to occur on their account.
140 self._sso_auth_confirm_template = load_jinja2_templates(
141 hs.config.sso_template_dir, ["sso_auth_confirm.html"],
142 )[0]
224 self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
225
143226 # The following template is shown after a successful user interactive
144227 # authentication session. It tells the user they can close the window.
145228 self._sso_auth_success_template = hs.config.sso_auth_success_template
229
146230 # The following template is shown during the SSO authentication process if
147231 # the account is deactivated.
148232 self._sso_account_deactivated_template = (
365449 # authentication flow.
366450 await self.store.set_ui_auth_clientdict(sid, clientdict)
367451
452 user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
453 0
454 ].decode("ascii", "surrogateescape")
455
456 await self.store.add_user_agent_ip_to_ui_auth_session(
457 session.session_id, user_agent, clientip
458 )
459
368460 if not authdict:
369461 raise InteractiveAuthIncompleteError(
370462 session.session_id, self._auth_dict_for_flows(flows, session.session_id)
11431235
11441236
11451237 @attr.s
1146 class MacaroonGenerator(object):
1238 class MacaroonGenerator:
11471239
11481240 hs = attr.ib()
11491241
3434 """
3535
3636 def __init__(self, hs):
37 self.hs = hs
3738 self._hostname = hs.hostname
3839 self._auth_handler = hs.get_auth_handler()
3940 self._registration_handler = hs.get_registration_handler()
209210
210211 else:
211212 if not registered_user_id:
213 # Pull out the user-agent and IP from the request.
214 user_agent = request.requestHeaders.getRawHeaders(
215 b"User-Agent", default=[b""]
216 )[0].decode("ascii", "surrogateescape")
217 ip_address = self.hs.get_ip_from_request(request)
218
212219 registered_user_id = await self._registration_handler.register_user(
213 localpart=localpart, default_display_name=user_display_name
220 localpart=localpart,
221 default_display_name=user_display_name,
222 user_agent_ips=(user_agent, ip_address),
214223 )
215224
216225 await self._auth_handler.complete_sso_login(
233233 return result
234234
235235 async def on_federation_query_user_devices(self, user_id):
236 stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
236 stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
237 user_id
238 )
237239 master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
238240 self_signing_key = await self.store.get_e2e_cross_signing_key(
239241 user_id, "self_signing"
494496 device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
495497
496498
497 class DeviceListUpdater(object):
499 class DeviceListUpdater:
498500 "Handles incoming device list updates from federation and updates the DB"
499501
500502 def __init__(self, hs, device_handler):
1515 import logging
1616 from typing import Any, Dict
1717
18 from canonicaljson import json
19
2018 from synapse.api.errors import SynapseError
2119 from synapse.logging.context import run_in_background
2220 from synapse.logging.opentracing import (
2624 start_active_span,
2725 )
2826 from synapse.types import UserID, get_domain_from_id
27 from synapse.util import json_encoder
2928 from synapse.util.stringutils import random_string
3029
3130 logger = logging.getLogger(__name__)
3231
3332
34 class DeviceMessageHandler(object):
33 class DeviceMessageHandler:
3534 def __init__(self, hs):
3635 """
3736 Args:
173172 "sender": sender_user_id,
174173 "type": message_type,
175174 "message_id": message_id,
176 "org.matrix.opentracing_context": json.dumps(context),
175 "org.matrix.opentracing_context": json_encoder.encode(context),
177176 }
178177
179178 log_kv({"local_messages": local_messages})
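The `json_encoder` used above (and the matching `json_decoder` seen in later hunks) are shared instances from `synapse.util`, created once at import time and reused so each call avoids rebuilding an encoder or decoder. A sketch of the pattern (Synapse's exact encoder options may differ):

    import json

    # Module-level instances, built once and reused everywhere.
    json_encoder = json.JSONEncoder(separators=(",", ":"))  # compact output
    json_decoder = json.JSONDecoder()

    wire = json_encoder.encode({"event_id": "$abc"})  # '{"event_id":"$abc"}'
    obj = json_decoder.decode(wire)                   # {'event_id': '$abc'}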
2222 CodeMessageException,
2323 Codes,
2424 NotFoundError,
25 ShadowBanError,
2526 StoreError,
2627 SynapseError,
2728 )
198199
199200 try:
200201 await self._update_canonical_alias(requester, user_id, room_id, room_alias)
202 except ShadowBanError as e:
203 logger.info("Failed to update alias events due to shadow-ban: %s", e)
201204 except AuthError as e:
202205 logger.info("Failed to update alias events: %s", e)
203206
291294 """
292295 Send an updated canonical alias event if the removed alias was set as
293296 the canonical alias or listed in the alt_aliases field.
297
298 Raises:
299 ShadowBanError if the requester has been shadow-banned.
294300 """
295301 alias_event = await self.state.get_current_state(
296302 room_id, EventTypes.CanonicalAlias, ""
1818 from typing import Dict, List, Optional, Tuple
1919
2020 import attr
21 from canonicaljson import encode_canonical_json, json
21 from canonicaljson import encode_canonical_json
2222 from signedjson.key import VerifyKey, decode_verify_key_bytes
2323 from signedjson.sign import SignatureVerifyException, verify_signed_json
2424 from unpaddedbase64 import decode_base64
3434 get_domain_from_id,
3535 get_verify_key_from_cross_signing_key,
3636 )
37 from synapse.util import unwrapFirstError
37 from synapse.util import json_decoder, unwrapFirstError
3838 from synapse.util.async_helpers import Linearizer
3939 from synapse.util.caches.expiringcache import ExpiringCache
4040 from synapse.util.retryutils import NotRetryingDestination
4242 logger = logging.getLogger(__name__)
4343
4444
45 class E2eKeysHandler(object):
45 class E2eKeysHandler:
4646 def __init__(self, hs):
4747 self.store = hs.get_datastore()
4848 self.federation = hs.get_federation_client()
352352 # make sure that each queried user appears in the result dict
353353 result_dict[user_id] = {}
354354
355 results = await self.store.get_e2e_device_keys(local_query)
355 results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
356356
357357 # Build the result structure
358358 for user_id, device_keys in results.items():
403403 for device_id, keys in device_keys.items():
404404 for key_id, json_bytes in keys.items():
405405 json_result.setdefault(user_id, {})[device_id] = {
406 key_id: json.loads(json_bytes)
406 key_id: json_decoder.decode(json_bytes)
407407 }
408408
409409 @trace
733733 # fetch our stored devices. This is used to 1. verify
734734 # signatures on the master key, and 2. to compare with what
735735 # was sent if the device was signed
736 devices = await self.store.get_e2e_device_keys([(user_id, None)])
736 devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
737737
738738 if user_id not in devices:
739739 raise NotFoundError("No device keys found")
11851185
11861186
11871187 def _one_time_keys_match(old_key_json, new_key):
1188 old_key = json.loads(old_key_json)
1188 old_key = json_decoder.decode(old_key_json)
11891189
11901190 # if either is a string rather than an object, they must match exactly
11911191 if not isinstance(old_key, dict) or not isinstance(new_key, dict):
12111211 signature = attr.ib()
12121212
12131213
1214 class SigningKeyEduUpdater(object):
1214 class SigningKeyEduUpdater:
12151215 """Handles incoming signing key updates from federation and updates the DB"""
12161216
12171217 def __init__(self, hs, e2e_keys_handler):
2828 logger = logging.getLogger(__name__)
2929
3030
31 class E2eRoomKeysHandler(object):
31 class E2eRoomKeysHandler:
3232 """
3333 Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
3434 This gives a way for users to store and recover their megolm keys if they lose all
1414
1515 import logging
1616 import random
17 from typing import TYPE_CHECKING, Iterable, List, Optional
1718
1819 from synapse.api.constants import EventTypes, Membership
1920 from synapse.api.errors import AuthError, SynapseError
2021 from synapse.events import EventBase
2122 from synapse.handlers.presence import format_user_presence_state
2223 from synapse.logging.utils import log_function
23 from synapse.types import UserID
24 from synapse.streams.config import PaginationConfig
25 from synapse.types import JsonDict, UserID
2426 from synapse.visibility import filter_events_for_client
2527
2628 from ._base import BaseHandler
29
30 if TYPE_CHECKING:
31 from synapse.server import HomeServer
32
2733
2834 logger = logging.getLogger(__name__)
2935
3036
3137 class EventStreamHandler(BaseHandler):
32 def __init__(self, hs):
38 def __init__(self, hs: "HomeServer"):
3339 super(EventStreamHandler, self).__init__(hs)
34
35 # Count of active streams per user
36 self._streams_per_user = {}
37 # Grace timers per user to delay the "stopped" signal
38 self._stop_timer_per_user = {}
3940
4041 self.distributor = hs.get_distributor()
4142 self.distributor.declare("started_user_eventstream")
5152 @log_function
5253 async def get_stream(
5354 self,
54 auth_user_id,
55 pagin_config,
56 timeout=0,
57 as_client_event=True,
58 affect_presence=True,
59 room_id=None,
60 is_guest=False,
61 ):
55 auth_user_id: str,
56 pagin_config: PaginationConfig,
57 timeout: int = 0,
58 as_client_event: bool = True,
59 affect_presence: bool = True,
60 room_id: Optional[str] = None,
61 is_guest: bool = False,
62 ) -> JsonDict:
6263 """Fetches the events stream for a given user.
6364 """
6465
9798
9899 # When the user joins a new room, or another user joins a currently
99100 # joined room, we need to send down presence for those users.
100 to_add = []
101 to_add = [] # type: List[JsonDict]
101102 for event in events:
102103 if not isinstance(event, EventBase):
103104 continue
109110 # Send down presence for everyone in the room.
110111 users = await self.state.get_current_users_in_room(
111112 event.room_id
112 )
113 ) # type: Iterable[str]
113114 else:
114115 users = [event.state_key]
115116
143144
144145
145146 class EventHandler(BaseHandler):
146 def __init__(self, hs):
147 def __init__(self, hs: "HomeServer"):
147148 super(EventHandler, self).__init__(hs)
148149 self.storage = hs.get_storage()
149150
150 async def get_event(self, user, room_id, event_id):
151 async def get_event(
152 self, user: UserID, room_id: Optional[str], event_id: str
153 ) -> Optional[EventBase]:
151154 """Retrieve a single specified event.
152155
153156 Args:
154 user (synapse.types.UserID): The user requesting the event
155 room_id (str|None): The expected room id. We'll return None if the
157 user: The user requesting the event
158 room_id: The expected room id. We'll return None if the
156159 event's room does not match.
157 event_id (str): The event ID to obtain.
160 event_id: The event ID to obtain.
158161 Returns:
159 dict: An event, or None if there is no event matching this ID.
162 An event, or None if there is no event matching this ID.
160163 Raises:
161164 SynapseError if there was a problem retrieving this event, or
162165 AuthError if the user does not have the rights to inspect this
7171 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
7272 from synapse.state import StateResolutionStore, resolve_events_with_store
7373 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
74 from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
74 from synapse.types import (
75 JsonDict,
76 MutableStateMap,
77 StateMap,
78 UserID,
79 get_domain_from_id,
80 )
7581 from synapse.util.async_helpers import Linearizer, concurrently_execute
7682 from synapse.util.distributor import user_joined_room
7783 from synapse.util.retryutils import NotRetryingDestination
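`MutableStateMap` joins the existing `StateMap` alias from `synapse.types`; roughly (a sketch, not the verbatim definitions), they are generic aliases keyed on (event type, state key) pairs:

    from typing import Mapping, MutableMapping, Tuple, TypeVar

    T = TypeVar("T")
    StateKey = Tuple[str, str]                     # (event type, state key)
    StateMap = Mapping[StateKey, T]                # read-only room state
    MutableStateMap = MutableMapping[StateKey, T]  # state the callee may mutate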
95101
96102 event = attr.ib(type=EventBase)
97103 state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
98 auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
104 auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
99105
100106
101107 class FederationHandler(BaseHandler):
433439 if not prevs - seen:
434440 return
435441
436 latest = await self.store.get_latest_event_ids_in_room(room_id)
442 latest_list = await self.store.get_latest_event_ids_in_room(room_id)
437443
438444 # We add the prev events that we have seen to the latest
439445 # list to ensure the remote server doesn't give them to us
440 latest = set(latest)
446 latest = set(latest_list)
441447 latest |= seen
442448
443449 logger.info(
774780 # keys across all devices.
775781 current_keys = [
776782 key
777 for device in cached_devices
783 for device in cached_devices.values()
778784 for key in device.get("keys", {}).get("keys", {}).values()
779785 ]
780786
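The `.values()` fix above addresses a classic Python slip: iterating a dict yields its keys, so the old comprehension walked device ID strings and `device.get(...)` would blow up with `AttributeError`. In miniature:

    cached_devices = {"DEV1": {"keys": {"keys": {"ed25519:DEV1": "key+b64"}}}}

    for device in cached_devices:            # yields "DEV1", a str
        assert isinstance(device, str)

    for device in cached_devices.values():   # yields the device dicts
        assert "keys" in device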
936942
937943 return events
938944
939 async def maybe_backfill(self, room_id, current_depth):
945 async def maybe_backfill(
946 self, room_id: str, current_depth: int, limit: int
947 ) -> bool:
940948 """Checks the database to see if we should backfill before paginating,
941949 and if so, does so.
950
951 Args:
952 room_id: The room to backfill in.
953 current_depth: The depth we're paginating from. This is
954 used to decide if we should backfill and what extremities to
955 use.
956 limit: The number of events that the pagination request will
957 return. This is used as part of the heuristic to decide if we
958 should back paginate.
942959 """
943960 extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
944961
945962 if not extremities:
946963 logger.debug("Not backfilling as no extremities found.")
947 return
964 return False
948965
949966 # We only want to paginate if we can actually see the events we'll get,
950967 # as otherwise we'll just spend a lot of resources to get redacted
9971014 sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
9981015 max_depth = sorted_extremeties_tuple[0][1]
9991016
1017 # If we're approaching an extremity we trigger a backfill, otherwise we
1018 # no-op.
1019 #
1020 # We chose twice the limit here as then clients paginating backwards
1021 # will send pagination requests that trigger backfill at least twice
1022 # using the most recent extremity before it gets removed (see below). We
1023 # chose more than one multiple of the limit to allow for failures, but choosing a
1024 # much larger factor will result in triggering a backfill request much
1025 # earlier than necessary.
1026 if current_depth - 2 * limit > max_depth:
1027 logger.debug(
1028 "Not backfilling as we don't need to. %d < %d - 2 * %d",
1029 max_depth,
1030 current_depth,
1031 limit,
1032 )
1033 return False
1034
1035 logger.debug(
1036 "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
1037 room_id,
1038 current_depth,
1039 max_depth,
1040 sorted_extremeties_tuple,
1041 )
1042
1043 # We ignore extremities that have a greater depth than our current depth
1044 # as:
1045 # 1. we don't really care about getting events that have happened
1046 # before our current position; and
1047 # 2. we have likely previously tried and failed to backfill from that
1048 # extremity, so to avoid getting "stuck" requesting the same
1049 # backfill repeatedly we drop those extremities.
1050 filtered_sorted_extremeties_tuple = [
1051 t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
1052 ]
1053
1054 # However, we need to check that the filtered extremities are non-empty.
1055 # If they are empty then we can either a) bail or b) still attempt to
1056 # backfill. We opt to try backfilling anyway just in case we do get
1057 # relevant events.
1058 if filtered_sorted_extremeties_tuple:
1059 sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
1060
10001061 # We don't want to specify too many extremities as it causes the backfill
10011062 # request URI to be too long.
10021063 extremities = dict(sorted_extremeties_tuple[:5])
1003
1004 if current_depth > max_depth:
1005 logger.debug(
1006 "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
1007 )
1008 return
10091064
10101065 # Now we need to decide which hosts to hit first.
10111066
17761831 """Returns the state at the event. i.e. not including said event.
17771832 """
17781833
1779 event = await self.store.get_event(
1780 event_id, allow_none=False, check_room_id=room_id
1781 )
1834 event = await self.store.get_event(event_id, check_room_id=room_id)
17821835
17831836 state_groups = await self.state_store.get_state_groups(room_id, [event_id])
17841837
18041857 async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
18051858 """Returns the state at the event. i.e. not including said event.
18061859 """
1807 event = await self.store.get_event(
1808 event_id, allow_none=False, check_room_id=room_id
1809 )
1860 event = await self.store.get_event(event_id, check_room_id=room_id)
18101861
18111862 state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
18121863
18761927 else:
18771928 return None
18781929
1879 def get_min_depth_for_context(self, context):
1880 return self.store.get_min_depth(context)
1930 async def get_min_depth_for_context(self, context):
1931 return await self.store.get_min_depth(context)
18811932
18821933 async def _handle_new_event(
18831934 self, origin, event, state=None, auth_events=None, backfilled=False
20562107 origin: str,
20572108 event: EventBase,
20582109 state: Optional[Iterable[EventBase]],
2059 auth_events: Optional[StateMap[EventBase]],
2110 auth_events: Optional[MutableStateMap[EventBase]],
20602111 backfilled: bool,
20612112 ) -> EventContext:
20622113 context = await self.state_handler.compute_event_context(event, old_state=state)
21062157 if backfilled or event.internal_metadata.is_outlier():
21072158 return
21082159
2109 extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
2110 extrem_ids = set(extrem_ids)
2160 extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
2161 extrem_ids = set(extrem_ids_list)
21112162 prev_event_ids = set(event.prev_event_ids())
21122163
21132164 if extrem_ids == prev_event_ids:
21372188 )
21382189 state_sets = list(state_sets.values())
21392190 state_sets.append(state)
2140 current_state_ids = await self.state_handler.resolve_events(
2191 current_states = await self.state_handler.resolve_events(
21412192 room_version, state_sets, event
21422193 )
2143 current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
2194 current_state_ids = {
2195 k: e.event_id for k, e in current_states.items()
2196 } # type: StateMap[str]
21442197 else:
21452198 current_state_ids = await self.state_handler.get_current_state_ids(
21462199 event.room_id, latest_event_ids=extrem_ids
21522205
21532206 # Now check if event pass auth against said current state
21542207 auth_types = auth_types_for_event(event)
2155 current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
2156
2157 current_auth_events = await self.store.get_events(current_state_ids)
2208 current_state_ids_list = [
2209 e for k, e in current_state_ids.items() if k in auth_types
2210 ]
2211
2212 auth_events_map = await self.store.get_events(current_state_ids_list)
21582213 current_auth_events = {
2159 (e.type, e.state_key): e for e in current_auth_events.values()
2214 (e.type, e.state_key): e for e in auth_events_map.values()
21602215 }
21612216
21622217 try:
21722227 if not in_room:
21732228 raise AuthError(403, "Host not in room.")
21742229
2175 event = await self.store.get_event(
2176 event_id, allow_none=False, check_room_id=room_id
2177 )
2230 event = await self.store.get_event(event_id, check_room_id=room_id)
21782231
21792232 # Just go through and process each event in `remote_auth_chain`. We
21802233 # don't want to fall into the trap of `missing` being wrong.
22262279 origin: str,
22272280 event: EventBase,
22282281 context: EventContext,
2229 auth_events: StateMap[EventBase],
2282 auth_events: MutableStateMap[EventBase],
22302283 ) -> EventContext:
22312284 """
22322285
22772330 origin: str,
22782331 event: EventBase,
22792332 context: EventContext,
2280 auth_events: StateMap[EventBase],
2333 auth_events: MutableStateMap[EventBase],
22812334 ) -> EventContext:
22822335 """Helper for do_auth. See there for docs.
22832336
5151 return f
5252
5353
54 class GroupsLocalWorkerHandler(object):
54 class GroupsLocalWorkerHandler:
5555 def __init__(self, hs):
5656 self.hs = hs
5757 self.store = hs.get_datastore()
2020 import urllib.parse
2121 from typing import Awaitable, Callable, Dict, List, Optional, Tuple
2222
23 from canonicaljson import json
24
2523 from twisted.internet.error import TimeoutError
2624
2725 from synapse.api.errors import (
3331 from synapse.config.emailconfig import ThreepidBehaviour
3432 from synapse.http.client import SimpleHttpClient
3533 from synapse.types import JsonDict, Requester
34 from synapse.util import json_decoder
3635 from synapse.util.hash import sha256_and_url_safe_base64
3736 from synapse.util.stringutils import assert_valid_client_secret, random_string
3837
176175 except TimeoutError:
177176 raise SynapseError(500, "Timed out contacting identity server")
178177 except CodeMessageException as e:
179 data = json.loads(e.msg) # XXX WAT?
178 data = json_decoder.decode(e.msg) # XXX WAT?
180179 return data
181180
182181 logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
1313 # limitations under the License.
1414
1515 import logging
16 from typing import TYPE_CHECKING
1617
1718 from twisted.internet import defer
1819
2122 from synapse.events.validator import EventValidator
2223 from synapse.handlers.presence import format_user_presence_state
2324 from synapse.logging.context import make_deferred_yieldable, run_in_background
25 from synapse.storage.roommember import RoomsForUser
2426 from synapse.streams.config import PaginationConfig
25 from synapse.types import StreamToken, UserID
27 from synapse.types import JsonDict, Requester, StreamToken, UserID
2628 from synapse.util import unwrapFirstError
2729 from synapse.util.async_helpers import concurrently_execute
2830 from synapse.util.caches.response_cache import ResponseCache
3032
3133 from ._base import BaseHandler
3234
35 if TYPE_CHECKING:
36 from synapse.server import HomeServer
37
38
3339 logger = logging.getLogger(__name__)
3440
3541
3642 class InitialSyncHandler(BaseHandler):
37 def __init__(self, hs):
43 def __init__(self, hs: "HomeServer"):
3844 super(InitialSyncHandler, self).__init__(hs)
3945 self.hs = hs
4046 self.state = hs.get_state_handler()
4753
4854 def snapshot_all_rooms(
4955 self,
50 user_id=None,
51 pagin_config=None,
52 as_client_event=True,
53 include_archived=False,
54 ):
56 user_id: str,
57 pagin_config: PaginationConfig,
58 as_client_event: bool = True,
59 include_archived: bool = False,
60 ) -> JsonDict:
5561 """Retrieve a snapshot of all rooms the user is invited or has joined.
5662
5763 This snapshot may include messages for all rooms where the user is
5864 joined, depending on the pagination config.
5965
6066 Args:
61 user_id (str): The ID of the user making the request.
62 pagin_config (synapse.api.streams.PaginationConfig): The pagination
63 config used to determine how many messages *PER ROOM* to return.
64 as_client_event (bool): True to get events in client-server format.
65 include_archived (bool): True to get rooms that the user has left
67 user_id: The ID of the user making the request.
68 pagin_config: The pagination config used to determine how many
69 messages *PER ROOM* to return.
70 as_client_event: True to get events in client-server format.
71 include_archived: True to get rooms that the user has left
6672 Returns:
67 A list of dicts with "room_id" and "membership" keys for all rooms
68 the user is currently invited or joined in on. Rooms where the user
69 is joined on, may return a "messages" key with messages, depending
70 on the specified PaginationConfig.
73 A JsonDict with the same format as the response to the `/initialSync`
74 API.
7175 """
7276 key = (
7377 user_id,
9094
9195 async def _snapshot_all_rooms(
9296 self,
93 user_id=None,
94 pagin_config=None,
95 as_client_event=True,
96 include_archived=False,
97 ):
97 user_id: str,
98 pagin_config: PaginationConfig,
99 as_client_event: bool = True,
100 include_archived: bool = False,
101 ) -> JsonDict:
98102
99103 memberships = [Membership.INVITE, Membership.JOIN]
100104 if include_archived:
133137 if limit is None:
134138 limit = 10
135139
136 async def handle_room(event):
140 async def handle_room(event: RoomsForUser):
137141 d = {
138142 "room_id": event.room_id,
139143 "membership": event.membership,
250254
251255 return ret
252256
253 async def room_initial_sync(self, requester, room_id, pagin_config=None):
257 async def room_initial_sync(
258 self, requester: Requester, room_id: str, pagin_config: PaginationConfig
259 ) -> JsonDict:
254260 """Capture the a snapshot of a room. If user is currently a member of
255261 the room this will be what is currently in the room. If the user left
256262 the room this will be what was in the room when they left.
257263
258264 Args:
259 requester(Requester): The user to get a snapshot for.
260 room_id(str): The room to get a snapshot of.
261 pagin_config(synapse.streams.config.PaginationConfig):
262 The pagination config used to determine how many messages to
263 return.
265 requester: The user to get a snapshot for.
266 room_id: The room to get a snapshot of.
267 pagin_config: The pagination config used to determine how many
268 messages to return.
264269 Raises:
265270 AuthError if the user wasn't in the room.
266271 Returns:
304309 return result
305310
306311 async def _room_initial_sync_parted(
307 self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
308 ):
312 self,
313 user_id: str,
314 room_id: str,
315 pagin_config: PaginationConfig,
316 membership: Membership,
317 member_event_id: str,
318 is_peeking: bool,
319 ) -> JsonDict:
309320 room_state = await self.state_store.get_state_for_events([member_event_id])
310321
311322 room_state = room_state[member_event_id]
349360 }
350361
351362 async def _room_initial_sync_joined(
352 self, user_id, room_id, pagin_config, membership, is_peeking
353 ):
363 self,
364 user_id: str,
365 room_id: str,
366 pagin_config: PaginationConfig,
367 membership: Membership,
368 is_peeking: bool,
369 ) -> JsonDict:
354370 current_state = await self.state.get_current_state(room_id=room_id)
355371
356372 # TODO: These concurrently
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616 import logging
17 import random
1718 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
1819
19 from canonicaljson import encode_canonical_json, json
20 from canonicaljson import encode_canonical_json
2021
2122 from twisted.internet.interfaces import IDelayedCall
2223
3334 Codes,
3435 ConsentNotGivenError,
3536 NotFoundError,
37 ShadowBanError,
3638 SynapseError,
3739 )
3840 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
4648 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
4749 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
4850 from synapse.storage.state import StateFilter
49 from synapse.types import (
50 Collection,
51 Requester,
52 RoomAlias,
53 StreamToken,
54 UserID,
55 create_requester,
56 )
51 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
52 from synapse.util import json_decoder
5753 from synapse.util.async_helpers import Linearizer
5854 from synapse.util.frozenutils import frozendict_json_encoder
5955 from synapse.util.metrics import measure_func
6763 logger = logging.getLogger(__name__)
6864
6965
70 class MessageHandler(object):
66 class MessageHandler:
7167 """Contains some read only APIs to get state about a room
7268 """
7369
9187 )
9288
9389 async def get_room_data(
94 self,
95 user_id: str,
96 room_id: str,
97 event_type: str,
98 state_key: str,
99 is_guest: bool,
90 self, user_id: str, room_id: str, event_type: str, state_key: str,
10091 ) -> dict:
10192 """ Get data from a room.
10293
10596 room_id
10697 event_type
10798 state_key
108 is_guest
10999 Returns:
110100 The path data content.
111101 Raises:
112 SynapseError if something went wrong.
102 SynapseError or AuthError if the user is not in the room
113103 """
114104 (
115105 membership,
126116 [membership_event_id], StateFilter.from_types([key])
127117 )
128118 data = room_state[membership_event_id].get(key)
119 else:
120 # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should
121 # only ever return a Membership.JOIN/LEAVE object
122 #
123 # Safeguard in case it returned something else
124 logger.error(
125 "Attempted to retrieve data from a room for a user that has never been in it. "
126 "This should not have happened."
127 )
128 raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
129129
130130 return data
131131
360360 _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
361361
362362
363 class EventCreationHandler(object):
363 class EventCreationHandler:
364364 def __init__(self, hs: "HomeServer"):
365365 self.hs = hs
366366 self.auth = hs.get_auth()
438438 event_dict: dict,
439439 token_id: Optional[str] = None,
440440 txn_id: Optional[str] = None,
441 prev_event_ids: Optional[Collection[str]] = None,
441 prev_event_ids: Optional[List[str]] = None,
442442 require_consent: bool = True,
443443 ) -> Tuple[EventBase, EventContext]:
444444 """
643643 event: EventBase,
644644 context: EventContext,
645645 ratelimit: bool = True,
646 ignore_shadow_ban: bool = False,
646647 ) -> int:
647648 """
648649 Persists and notifies local clients and federation of an event.
649650
650651 Args:
651 requester
652 event the event to send.
653 context: the context of the event.
652 requester: The requester sending the event.
653 event: The event to send.
654 context: The context of the event.
654655 ratelimit: Whether to rate limit this send.
656 ignore_shadow_ban: True if shadow-banned users should be allowed to
657 send this event.
655658
656659 Return:
657660 The stream_id of the persisted event.
661
662 Raises:
663 ShadowBanError if the requester has been shadow-banned.
658664 """
659665 if event.type == EventTypes.Member:
660666 raise SynapseError(
661667 500, "Tried to send member event through non-member codepath"
662668 )
663669
670 if not ignore_shadow_ban and requester.shadow_banned:
671 # We randomly sleep a bit just to annoy the requester.
672 await self.clock.sleep(random.randint(1, 10))
673 raise ShadowBanError()
674
664675 user = UserID.from_string(event.sender)
665676
666677 assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
667678
668679 if event.is_state():
669 prev_state = await self.deduplicate_state_event(event, context)
670 if prev_state is not None:
680 prev_event = await self.deduplicate_state_event(event, context)
681 if prev_event is not None:
671682 logger.info(
672683 "Not bothering to persist state event %s duplicated by %s",
673684 event.event_id,
674 prev_state.event_id,
675 )
676 return prev_state
685 prev_event.event_id,
686 )
687 return await self.store.get_stream_id_for_event(prev_event.event_id)
677688
678689 return await self.handle_new_client_event(
679690 requester=requester, event=event, context=context, ratelimit=ratelimit
681692
682693 async def deduplicate_state_event(
683694 self, event: EventBase, context: EventContext
684 ) -> None:
695 ) -> Optional[EventBase]:
685696 """
686697 Checks whether event is in the latest resolved state in context.
687698
688 If so, returns the version of the event in context.
689 Otherwise, returns None.
699 Args:
700 event: The event to check for duplication.
701 context: The event context.
702
703 Returns:
704 The previous version of the event is returned, if it is found in the
705 event context. Otherwise, None is returned.
690706 """
691707 prev_state_ids = await context.get_prev_state_ids()
692708 prev_event_id = prev_state_ids.get((event.type, event.state_key))
693709 if not prev_event_id:
694 return
710 return None
695711 prev_event = await self.store.get_event(prev_event_id, allow_none=True)
696712 if not prev_event:
697 return
713 return None
698714
699715 if prev_event and event.user_id == prev_event.user_id:
700716 prev_content = encode_canonical_json(prev_event.content)
701717 next_content = encode_canonical_json(event.content)
702718 if prev_content == next_content:
703719 return prev_event
704 return
720 return None
705721
706722 async def create_and_send_nonmember_event(
707723 self,
709725 event_dict: dict,
710726 ratelimit: bool = True,
711727 txn_id: Optional[str] = None,
728 ignore_shadow_ban: bool = False,
712729 ) -> Tuple[EventBase, int]:
713730 """
714731 Creates an event, then sends it.
715732
716733 See self.create_event and self.send_nonmember_event.
717 """
734
735 Args:
736 requester: The requester sending the event.
737 event_dict: An entire event.
738 ratelimit: Whether to rate limit this send.
739 txn_id: The transaction ID.
740 ignore_shadow_ban: True if shadow-banned users should be allowed to
741 send this event.
742
743 Raises:
744 ShadowBanError if the requester has been shadow-banned.
745 """
746 if not ignore_shadow_ban and requester.shadow_banned:
747 # We randomly sleep a bit just to annoy the requester.
748 await self.clock.sleep(random.randint(1, 10))
749 raise ShadowBanError()
718750
719751 # We limit the number of concurrent event sends in a room so that we
720752 # don't fork the DAG too much. If we don't limit then we can end up in
733765 raise SynapseError(403, spam_error, Codes.FORBIDDEN)
734766
735767 stream_id = await self.send_nonmember_event(
736 requester, event, context, ratelimit=ratelimit
768 requester,
769 event,
770 context,
771 ratelimit=ratelimit,
772 ignore_shadow_ban=ignore_shadow_ban,
737773 )
738774 return event, stream_id
739775
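The shadow-ban guard added in both send paths follows one pattern: unless the caller opts out, a shadow-banned requester gets a short random sleep and then a `ShadowBanError`, which callers either swallow (as the directory handler hunk earlier does) or pass off as success. Condensed into a hypothetical helper:

    import random

    from synapse.api.errors import ShadowBanError

    async def check_shadow_ban(clock, requester, ignore_shadow_ban=False):
        """Hypothetical helper mirroring the inline guard above."""
        if not ignore_shadow_ban and requester.shadow_banned:
            # We randomly sleep a bit just to annoy the requester.
            await clock.sleep(random.randint(1, 10))
            raise ShadowBanError()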
742778 self,
743779 builder: EventBuilder,
744780 requester: Optional[Requester] = None,
745 prev_event_ids: Optional[Collection[str]] = None,
781 prev_event_ids: Optional[List[str]] = None,
746782 ) -> Tuple[EventBase, EventContext]:
747783 """Create a new event for a local client
748784
858894 # Ensure that we can round trip before trying to persist in db
859895 try:
860896 dump = frozendict_json_encoder.encode(event.content)
861 json.loads(dump)
897 json_decoder.decode(dump)
862898 except Exception:
863899 logger.exception("Failed to encode content: %r", event.content)
864900 raise
890926 except Exception:
891927 # Ensure that we actually remove the entries in the push actions
892928 # staging area, if we calculated them.
893 run_in_background(
894 self.store.remove_push_actions_from_staging, event.event_id
895 )
929 await self.store.remove_push_actions_from_staging(event.event_id)
896930 raise
897931
898932 async def _validate_canonical_alias(
956990 allow_none=True,
957991 )
958992
959 is_admin_redaction = (
993 is_admin_redaction = bool(
960994 original_event and event.sender != original_event.sender
961995 )
962996
10761110 auth_events_ids = self.auth.compute_auth_events(
10771111 event, prev_state_ids, for_verification=True
10781112 )
1079 auth_events = await self.store.get_events(auth_events_ids)
1080 auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
1113 auth_events_map = await self.store.get_events(auth_events_ids)
1114 auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
10811115
10821116 room_version = await self.store.get_room_version_id(event.room_id)
10831117 room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
11751209
11761210 event.internal_metadata.proactively_send = False
11771211
1212 # Since this is a dummy-event it is OK if it is sent by a
1213 # shadow-banned user.
11781214 await self.send_nonmember_event(
1179 requester, event, context, ratelimit=False
1215 requester,
1216 event,
1217 context,
1218 ratelimit=False,
1219 ignore_shadow_ban=True,
11801220 )
11811221 dummy_event_sent = True
11821222 break
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import json
1514 import logging
1615 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
1716 from urllib.parse import urlencode
3736 from synapse.http.server import respond_with_html
3837 from synapse.http.site import SynapseRequest
3938 from synapse.logging.context import make_deferred_yieldable
40 from synapse.push.mailer import load_jinja2_templates
4139 from synapse.types import UserID, map_username_to_mxid_localpart
40 from synapse.util import json_decoder
4241
4342 if TYPE_CHECKING:
4443 from synapse.server import HomeServer
9392 """
9493
9594 def __init__(self, hs: "HomeServer"):
95 self.hs = hs
9696 self._callback_url = hs.config.oidc_callback_url # type: str
9797 self._scopes = hs.config.oidc_scopes # type: List[str]
9898 self._client_auth = ClientAuth(
122122 self._hostname = hs.hostname # type: str
123123 self._server_name = hs.config.server_name # type: str
124124 self._macaroon_secret_key = hs.config.macaroon_secret_key
125 self._error_template = load_jinja2_templates(
126 hs.config.sso_template_dir, ["sso_error.html"]
127 )[0]
125 self._error_template = hs.config.sso_error_template
128126
129127 # identifier for the external_ids table
130128 self._auth_provider_id = "oidc"
369367 # and check for an error field. If not, we respond with a generic
370368 # error message.
371369 try:
372 resp = json.loads(resp_body.decode("utf-8"))
370 resp = json_decoder.decode(resp_body.decode("utf-8"))
373371 error = resp["error"]
374372 description = resp.get("error_description", error)
375373 except (ValueError, KeyError):
386384
387385 # Since it is a not a 5xx code, body should be a valid JSON. It will
388386 # raise if not.
389 resp = json.loads(resp_body.decode("utf-8"))
387 resp = json_decoder.decode(resp_body.decode("utf-8"))
390388
391389 if "error" in resp:
392390 error = resp["error"]
691689 self._render_error(request, "invalid_token", str(e))
692690 return
693691
692 # Pull out the user-agent and IP from the request.
693 user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
694 0
695 ].decode("ascii", "surrogateescape")
696 ip_address = self.hs.get_ip_from_request(request)
697
694698 # Call the mapper to register/login the user
695699 try:
696 user_id = await self._map_userinfo_to_user(userinfo, token)
700 user_id = await self._map_userinfo_to_user(
701 userinfo, token, user_agent, ip_address
702 )
697703 except MappingException as e:
698704 logger.exception("Could not map user")
699705 self._render_error(request, "mapping_error", str(e))
830836 now = self._clock.time_msec()
831837 return now < expiry
832838
833 async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
839 async def _map_userinfo_to_user(
840 self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
841 ) -> str:
834842 """Maps a UserInfo object to a mxid.
835843
836844 UserInfo should have a claim that uniquely identifies users. This claim
845853 Args:
846854 userinfo: an object representing the user
847855 token: a dict with the tokens obtained from the provider
856 user_agent: The user agent of the client making the request.
857 ip_address: The IP address of the client making the request.
848858
849859 Raises:
850860 MappingException: if there was an error while mapping some properties
858868 raise MappingException(
859869 "Failed to extract subject from OIDC response: %s" % (e,)
860870 )
871 # Some OIDC providers use integer IDs, but Synapse expects external IDs
872 # to be strings.
873 remote_user_id = str(remote_user_id)
861874
862875 logger.info(
863876 "Looking for existing mapping for user %s:%s",
901914 # It's the first time this user is logging in and the mapped mxid was
902915 # not taken, register the user
903916 registered_user_id = await self._registration_handler.register_user(
904 localpart=localpart, default_display_name=attributes["display_name"],
917 localpart=localpart,
918 default_display_name=attributes["display_name"],
919 user_agent_ips=(user_agent, ip_address),
905920 )
906921
907922 await self._datastore.record_user_external_id(
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515 import logging
16 from typing import TYPE_CHECKING, Any, Dict, Optional, Set
1617
1718 from twisted.python.failure import Failure
1819
1920 from synapse.api.constants import EventTypes, Membership
2021 from synapse.api.errors import SynapseError
22 from synapse.api.filtering import Filter
2123 from synapse.logging.context import run_in_background
2224 from synapse.metrics.background_process_metrics import run_as_background_process
2325 from synapse.storage.state import StateFilter
24 from synapse.types import RoomStreamToken
26 from synapse.streams.config import PaginationConfig
27 from synapse.types import Requester, RoomStreamToken
2528 from synapse.util.async_helpers import ReadWriteLock
2629 from synapse.util.stringutils import random_string
2730 from synapse.visibility import filter_events_for_client
2831
32 if TYPE_CHECKING:
33 from synapse.server import HomeServer
34
35
2936 logger = logging.getLogger(__name__)
3037
3138
32 class PurgeStatus(object):
39 class PurgeStatus:
3340 """Object tracking the status of a purge request
3441
3542 This class contains information on the progress of a purge request, for
5764 return {"status": PurgeStatus.STATUS_TEXT[self.status]}
5865
5966
60 class PaginationHandler(object):
67 class PaginationHandler:
6168 """Handles pagination and purge history requests.
6269
6370 These are in the same handler due to the fact we need to block clients
6471 paginating during a purge.
6572 """
6673
67 def __init__(self, hs):
74 def __init__(self, hs: "HomeServer"):
6875 self.hs = hs
6976 self.auth = hs.get_auth()
7077 self.store = hs.get_datastore()
7481 self._server_name = hs.hostname
7582
7683 self.pagination_lock = ReadWriteLock()
77 self._purges_in_progress_by_room = set()
84 self._purges_in_progress_by_room = set() # type: Set[str]
7885 # map from purge id to PurgeStatus
79 self._purges_by_id = {}
86 self._purges_by_id = {} # type: Dict[str, PurgeStatus]
8087 self._event_serializer = hs.get_event_client_serializer()
8188
8289 self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
90
91 self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
92 self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
8393
8494 if hs.config.retention_enabled:
8595 # Run the purge jobs described in the configuration file.
95105 job["longest_max_lifetime"],
96106 )
97107
98 async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
108 async def purge_history_for_rooms_in_range(
109 self, min_ms: Optional[int], max_ms: Optional[int]
110 ):
99111 """Purge outdated events from rooms within the given retention range.
100112
101113 If a default retention policy is defined in the server's configuration and its
103115 retention policy.
104116
105117 Args:
106 min_ms (int|None): Duration in milliseconds that define the lower limit of
118 min_ms: Duration in milliseconds that defines the lower limit of
107119 the range to handle (exclusive). If None, it means that the range has no
108120 lower limit.
109 max_ms (int|None): Duration in milliseconds that define the upper limit of
121 max_ms: Duration in milliseconds that defines the upper limit of
110122 the range to handle (inclusive). If None, it means that the range has no
111123 upper limit.
112124 """
113 # We want the storage layer to to include rooms with no retention policy in its
125 # We want the storage layer to include rooms with no retention policy in its
114126 # return value only if a default retention policy is defined in the server's
115127 # configuration and that policy's 'max_lifetime' is either lower (or equal) than
116128 # max_ms or higher than min_ms (or both).
151163 )
152164 continue
153165
154 max_lifetime = retention_policy["max_lifetime"]
155
156 if max_lifetime is None:
157 # If max_lifetime is None, it means that include_null equals True,
158 # therefore we can safely assume that there is a default policy defined
159 # in the server's configuration.
160 max_lifetime = self._retention_default_max_lifetime
166 # If max_lifetime is None, it means that the room has no retention policy.
167 # Given we only retrieve such rooms when there's a default retention policy
168 # defined in the server's configuration, we can safely assume that's the
169 # case and use it for this room.
170 max_lifetime = (
171 retention_policy["max_lifetime"] or self._retention_default_max_lifetime
172 )
173
174 # Cap the effective max_lifetime to be within the range allowed in the
175 # config.
176 # We do this in two steps:
177 # 1. Make sure it's higher or equal to the minimum allowed value, and if
178 # it's not replace it with that value. This is because the server
179 # operator can be required to not delete information before a given
180 # time, e.g. to comply with freedom of information laws.
181 # 2. Make sure the resulting value is lower or equal to the maximum allowed
182 # value, and if it's not replace it with that value. This is because the
183 # server operator can be required to delete any data after a specific
184 # amount of time.
185 if self._retention_allowed_lifetime_min is not None:
186 max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
187
188 if self._retention_allowed_lifetime_max is not None:
189 max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
190
191 logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
161192
162193 # Figure out what token we should start purging at.
163194 ts = self.clock.time_msec() - max_lifetime
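A worked example of the two-step clamp (config values invented; all figures in milliseconds):

    DAY_MS = 24 * 60 * 60 * 1000
    allowed_min, allowed_max = 7 * DAY_MS, 90 * DAY_MS

    for policy in (1 * DAY_MS, 30 * DAY_MS, 365 * DAY_MS):
        effective = max(allowed_min, policy)     # step 1: enforce the floor
        effective = min(effective, allowed_max)  # step 2: enforce the ceiling
        print(policy // DAY_MS, "->", effective // DAY_MS)  # 1->7, 30->30, 365->90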
194225 "_purge_history", self._purge_history, purge_id, room_id, token, True,
195226 )
196227
197 def start_purge_history(self, room_id, token, delete_local_events=False):
228 def start_purge_history(
229 self, room_id: str, token: str, delete_local_events: bool = False
230 ) -> str:
198231 """Start off a history purge on a room.
199232
200233 Args:
201 room_id (str): The room to purge from
202
203 token (str): topological token to delete events before
204 delete_local_events (bool): True to delete local events as well as
234 room_id: The room to purge from
235 token: topological token to delete events before
236 delete_local_events: True to delete local events as well as
205237 remote ones
206238
207239 Returns:
208 str: unique ID for this purge transaction.
240 unique ID for this purge transaction.
209241 """
210242 if room_id in self._purges_in_progress_by_room:
211243 raise SynapseError(
224256 )
225257 return purge_id
226258
227 async def _purge_history(self, purge_id, room_id, token, delete_local_events):
259 async def _purge_history(
260 self, purge_id: str, room_id: str, token: str, delete_local_events: bool
261 ) -> None:
228262 """Carry out a history purge on a room.
229263
230264 Args:
231 purge_id (str): The id for this purge
232 room_id (str): The room to purge from
233 token (str): topological token to delete events before
234 delete_local_events (bool): True to delete local events as well as
235 remote ones
265 purge_id: The id for this purge
266 room_id: The room to purge from
267 token: topological token to delete events before
268 delete_local_events: True to delete local events as well as remote ones
236269 """
237270 self._purges_in_progress_by_room.add(room_id)
238271 try:
257290
258291 self.hs.get_reactor().callLater(24 * 3600, clear_purge)
259292
260 def get_purge_status(self, purge_id):
293 def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
261294 """Get the current status of an active purge
262295
263296 Args:
264 purge_id (str): purge_id returned by start_purge_history
265
266 Returns:
267 PurgeStatus|None
297 purge_id: purge_id returned by start_purge_history
268298 """
269299 return self._purges_by_id.get(purge_id)
270300
271 async def purge_room(self, room_id):
301 async def purge_room(self, room_id: str) -> None:
272302 """Purge the given room from the database"""
273 with (await self.pagination_lock.write(room_id)):
303 with await self.pagination_lock.write(room_id):
274304 # check we know about the room
275305 await self.store.get_room_version_id(room_id)
276306
284314
285315 async def get_messages(
286316 self,
287 requester,
288 room_id=None,
289 pagin_config=None,
290 as_client_event=True,
291 event_filter=None,
292 ):
317 requester: Requester,
318 room_id: str,
319 pagin_config: PaginationConfig,
320 as_client_event: bool = True,
321 event_filter: Optional[Filter] = None,
322 ) -> Dict[str, Any]:
293323 """Get messages in a room.
294324
295325 Args:
296 requester (Requester): The user requesting messages.
297 room_id (str): The room they want messages from.
298 pagin_config (synapse.api.streams.PaginationConfig): The pagination
299 config rules to apply, if any.
300 as_client_event (bool): True to get events in client-server format.
301 event_filter (Filter): Filter to apply to results or None
326 requester: The user requesting messages.
327 room_id: The room they want messages from.
328 pagin_config: The pagination config rules to apply, if any.
329 as_client_event: True to get events in client-server format.
330 event_filter: Filter to apply to results or None
302331 Returns:
303 dict: Pagination API results
332 Pagination API results
304333 """
305334 user_id = requester.user.to_string()
306335
320349
321350 source_config = pagin_config.get_source_config("room")
322351
323 with (await self.pagination_lock.read(room_id)):
352 with await self.pagination_lock.read(room_id):
324353 (
325354 membership,
326355 member_event_id,
332361 # if we're going backwards, we might need to backfill. This
333362 # requires that we have a topo token.
334363 if room_token.topological:
335 max_topo = room_token.topological
364 curr_topo = room_token.topological
336365 else:
337 max_topo = await self.store.get_max_topological_token(
366 curr_topo = await self.store.get_current_topological_token(
338367 room_id, room_token.stream
339368 )
340369
342371 # If they have left the room then clamp the token to be before
343372 # they left the room, to save the effort of loading from the
344373 # database.
374
375 # This is only None if the room is world_readable, in which
376 # case "JOIN" would have been returned.
377 assert member_event_id
378
345379 leave_token = await self.store.get_topological_token_for_event(
346380 member_event_id
347381 )
348 leave_token = RoomStreamToken.parse(leave_token)
349 if leave_token.topological < max_topo:
382 if RoomStreamToken.parse(leave_token).topological < curr_topo:
350383 source_config.from_key = str(leave_token)
351384
352385 await self.hs.get_handlers().federation_handler.maybe_backfill(
353 room_id, max_topo
386 room_id, curr_topo, limit=source_config.limit,
354387 )
355388
356389 events, next_key = await self.store.paginate_room_events(
393426 )
394427
395428 if state_ids:
396 state = await self.store.get_events(list(state_ids.values()))
397 state = state.values()
429 state_dict = await self.store.get_events(list(state_ids.values()))
430 state = state_dict.values()
398431
399432 time_now = self.clock.time_msec()
400433
2121 logger = logging.getLogger(__name__)
2222
2323
24 class PasswordPolicyHandler(object):
24 class PasswordPolicyHandler:
2525 def __init__(self, hs):
2626 self.policy = hs.config.password_policy
2727 self.enabled = hs.config.password_policy_enabled
3232 import synapse.metrics
3333 from synapse.api.constants import EventTypes, Membership, PresenceState
3434 from synapse.api.errors import SynapseError
35 from synapse.api.presence import UserPresenceState
3536 from synapse.logging.context import run_in_background
3637 from synapse.logging.utils import log_function
3738 from synapse.metrics import LaterGauge
3839 from synapse.metrics.background_process_metrics import run_as_background_process
3940 from synapse.state import StateHandler
4041 from synapse.storage.databases.main import DataStore
41 from synapse.storage.presence import UserPresenceState
42 from synapse.types import JsonDict, UserID, get_domain_from_id
42 from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
4343 from synapse.util.async_helpers import Linearizer
4444 from synapse.util.caches.descriptors import cached
4545 from synapse.util.metrics import Measure
10091009 return content
10101010
10111011
1012 class PresenceEventSource(object):
1012 class PresenceEventSource:
10131013 def __init__(self, hs):
10141014 # We can't call get_presence_handler here because there's a cycle:
10151015 #
13171317
13181318 async def get_interested_remotes(
13191319 store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
1320 ) -> List[Tuple[List[str], List[UserPresenceState]]]:
1320 ) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
13211321 """Given a list of presence states figure out which remote servers
13221322 should be sent which.
13231323
13331333 each tuple the list of UserPresenceState should be sent to each
13341334 destination
13351335 """
1336 hosts_and_states = []
1336 hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
13371337
13381338 # First we look up the rooms each user is in (as well as any explicit
13391339 # subscriptions), then for each distinct room we look up the remote
1313 # limitations under the License.
1414
1515 import logging
16 import random
1617
1718 from synapse.api.errors import (
1819 AuthError,
159160 Codes.FORBIDDEN,
160161 )
161162
163 if not isinstance(new_displayname, str):
164 raise SynapseError(400, "Invalid displayname")
165
162166 if len(new_displayname) > MAX_DISPLAYNAME_LEN:
163167 raise SynapseError(
164168 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
212216 async def set_avatar_url(
213217 self, target_user, requester, new_avatar_url, by_admin=False
214218 ):
215 """target_user is the user whose avatar_url is to be changed;
216 auth_user is the user attempting to make this change."""
219 """Set a new avatar URL for a user.
220
221 Args:
222 target_user (UserID): the user whose avatar URL is to be changed.
223 requester (Requester): The user attempting to make this change.
224 new_avatar_url (str): The avatar URL to give this user.
225 by_admin (bool): Whether this change was made by an administrator.
226 """
217227 if not self.hs.is_mine(target_user):
218228 raise SynapseError(400, "User is not hosted on this homeserver")
219229
227237 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
228238 )
229239
240 if not isinstance(new_avatar_url, str):
241 raise SynapseError(400, "Invalid avatar URL")
242
230243 if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
231244 raise SynapseError(
232245 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
276289 return
277290
278291 await self.ratelimit(requester)
292
293 # Do not actually update the room state for shadow-banned users.
294 if requester.shadow_banned:
295 # We randomly sleep a bit just to annoy the requester.
296 await self.clock.sleep(random.randint(1, 10))
297 return
279298
280299 room_ids = await self.store.get_rooms_for_user(target_user.to_string())
281300
122122 await self.federation.send_read_receipt(receipt)
123123
124124
125 class ReceiptEventSource(object):
125 class ReceiptEventSource:
126126 def __init__(self, hs):
127127 self.store = hs.get_datastore()
128128
2525 ReplicationPostRegisterActionsServlet,
2626 ReplicationRegisterServlet,
2727 )
28 from synapse.spam_checker_api import RegistrationBehaviour
2829 from synapse.storage.state import StateFilter
2930 from synapse.types import RoomAlias, UserID, create_requester
3031
5152 self.macaroon_gen = hs.get_macaroon_generator()
5253 self._server_notices_mxid = hs.config.server_notices_mxid
5354
55 self.spam_checker = hs.get_spam_checker()
56
5457 if hs.config.worker_app:
5558 self._register_client = ReplicationRegisterServlet.make_client(hs)
5659 self._register_device_client = RegisterDeviceReplicationServlet.make_client(
123126 try:
124127 int(localpart)
125128 raise SynapseError(
126 400, "Numeric user IDs are reserved for guest users."
129 400,
130 "Numeric user IDs are reserved for guest users.",
131 errcode=Codes.INVALID_USERNAME,
127132 )
128133 except ValueError:
129134 pass
141146 address=None,
142147 bind_emails=[],
143148 by_admin=False,
149 user_agent_ips=None,
144150 ):
145151 """Registers a new client on the server.
146152
158164 bind_emails (List[str]): list of emails to bind to this account.
159165 by_admin (bool): True if this registration is being made via the
160166 admin api, otherwise False.
167 user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
168 during the registration process.
161169 Returns:
162170 str: user_id
163171 Raises:
164172 SynapseError if there was a problem registering.
165173 """
166174 self.check_registration_ratelimit(address)
175
176 result = self.spam_checker.check_registration_for_spam(
177 threepid, localpart, user_agent_ips or [],
178 )
179
180 if result == RegistrationBehaviour.DENY:
181 logger.info(
182 "Blocked registration of %r", localpart,
183 )
184 # We return a 429 to make it not obvious that they've been
185 # denied.
186 raise SynapseError(429, "Rate limited")
187
188 shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
189 if shadow_banned:
190 logger.info(
191 "Shadow banning registration of %r", localpart,
192 )
167193
168194 # do not check_auth_blocking if the call is coming through the Admin API
169195 if not by_admin:
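For module authors, the hook wired in above means a spam checker can now deny or shadow-ban registrations. A minimal sketch of such a module (class name and config handling are assumptions; the return values come from `synapse.spam_checker_api.RegistrationBehaviour`):

    from synapse.spam_checker_api import RegistrationBehaviour

    class ExampleRegistrationChecker:
        """Hypothetical spam-checker module."""

        def __init__(self, config):
            self._blocked_agents = config.get("blocked_user_agents", [])

        def check_registration_for_spam(self, email_threepid, username, request_info):
            # request_info is a collection of (user_agent, ip_address) pairs.
            for user_agent, _ip in request_info:
                if any(bad in user_agent for bad in self._blocked_agents):
                    return RegistrationBehaviour.SHADOW_BAN
            return RegistrationBehaviour.ALLOW

        @staticmethod
        def parse_config(config):
            return config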
193219 admin=admin,
194220 user_type=user_type,
195221 address=address,
222 shadow_banned=shadow_banned,
196223 )
197224
198225 if self.hs.config.user_directory_search_all_users:
223250 make_guest=make_guest,
224251 create_profile_with_displayname=default_display_name,
225252 address=address,
253 shadow_banned=shadow_banned,
226254 )
227255
228256 # Successfully registered
528556 admin=False,
529557 user_type=None,
530558 address=None,
559 shadow_banned=False,
531560 ):
532561 """Register user in the datastore.
533562
545574 user_type (str|None): type of user. One of the values from
546575 api.constants.UserTypes, or None for a normal user.
547576 address (str|None): the IP address used to perform the registration.
577 shadow_banned (bool): Whether to shadow-ban the user
548578
549579 Returns:
550580 Awaitable
560590 admin=admin,
561591 user_type=user_type,
562592 address=address,
593 shadow_banned=shadow_banned,
563594 )
564595 else:
565596 return self.store.register_user(
571602 create_profile_with_displayname=create_profile_with_displayname,
572603 admin=admin,
573604 user_type=user_type,
605 shadow_banned=shadow_banned,
574606 )
575607
576608 async def register_device(
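
The register.py hunks above add a check_registration_for_spam hook: a spam-checker module can inspect the threepid, the requested localpart and the (user-agent, IP) pairs collected during registration, and return a RegistrationBehaviour that either denies the registration (surfaced as a generic 429) or shadow-bans the new account. A sketch of what such a module might look like; the constructor and config shape are illustrative assumptions, and only DENY/SHADOW_BAN appear in the diff (ALLOW is assumed to be the permissive default):

from synapse.spam_checker_api import RegistrationBehaviour

class ExampleRegistrationChecker:
    """Illustrative registration spam checker (constructor signature assumed)."""

    def __init__(self, config):
        self._blocked_localparts = set(config.get("blocked_localparts", []))

    def check_registration_for_spam(self, email_threepid, username, request_info):
        # request_info carries the (user_agent, ip_address) tuples that
        # register.py threads through as `user_agent_ips`.
        if username in self._blocked_localparts:
            return RegistrationBehaviour.DENY
        for user_agent, _ip in request_info:
            if "curl" in user_agent.lower():
                # Let the registration complete, but shadow-ban the account.
                return RegistrationBehaviour.SHADOW_BAN
        return RegistrationBehaviour.ALLOW
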
1919 import itertools
2020 import logging
2121 import math
22 import random
2223 import string
2324 from collections import OrderedDict
24 from typing import Awaitable, Optional, Tuple
25 from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
2526
2627 from synapse.api.constants import (
2728 EventTypes,
3132 RoomEncryptionAlgorithms,
3233 )
3334 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
35 from synapse.api.filtering import Filter
3436 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
37 from synapse.events import EventBase
3538 from synapse.events.utils import copy_power_levels_contents
3639 from synapse.http.endpoint import parse_and_validate_server_name
3740 from synapse.storage.state import StateFilter
3841 from synapse.types import (
42 JsonDict,
43 MutableStateMap,
3944 Requester,
4045 RoomAlias,
4146 RoomID,
4651 create_requester,
4752 )
4853 from synapse.util import stringutils
49 from synapse.util.async_helpers import Linearizer, maybe_awaitable
54 from synapse.util.async_helpers import Linearizer
5055 from synapse.util.caches.response_cache import ResponseCache
5156 from synapse.visibility import filter_events_for_client
5257
5358 from ._base import BaseHandler
5459
60 if TYPE_CHECKING:
61 from synapse.server import HomeServer
62
5563 logger = logging.getLogger(__name__)
5664
5765 id_server_scheme = "https://"
6068
6169
6270 class RoomCreationHandler(BaseHandler):
63 def __init__(self, hs):
71 def __init__(self, hs: "HomeServer"):
6472 super(RoomCreationHandler, self).__init__(hs)
6573
6674 self.spam_checker = hs.get_spam_checker()
9199 "guest_can_join": False,
92100 "power_level_content_override": {},
93101 },
94 }
102 } # type: Dict[str, Dict[str, Any]]
95103
96104 # Modify presets to selectively enable encryption by default per homeserver config
97105 for preset_name, preset_config in self._presets_dict.items():
128136
129137 Returns:
130138 the new room id
139
140 Raises:
141 ShadowBanError if the requester is shadow-banned.
131142 """
132143 await self.ratelimit(requester)
133144
163174 async def _upgrade_room(
164175 self, requester: Requester, old_room_id: str, new_version: RoomVersion
165176 ):
177 """
178 Args:
179 requester: the user requesting the upgrade
180 old_room_id: the id of the room to be replaced
181 new_version: the version to upgrade the room to
182
183 Raises:
184 ShadowBanError if the requester is shadow-banned.
185 """
166186 user_id = requester.user.to_string()
167187
168188 # start by allocating a new room id
214234
215235 old_room_state = await tombstone_context.get_current_state_ids()
216236
237 # We know the tombstone event isn't an outlier so it has current state.
238 assert old_room_state is not None
239
217240 # update any aliases
218241 await self._move_aliases_to_new_room(
219242 requester, old_room_id, new_room_id, old_room_state
246269 old_room_id: the id of the room to be replaced
247270 new_room_id: the id of the replacement room
248271 old_room_state: the state map for the old room
272
273 Raises:
274 ShadowBanError if the requester is shadow-banned.
249275 """
250276 old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
251277
424450 old_room_member_state_events = await self.store.get_events(
425451 old_room_member_state_ids.values()
426452 )
427 for k, old_event in old_room_member_state_events.items():
453 for old_event in old_room_member_state_events.values():
428454 # Only transfer ban events
429455 if (
430456 "membership" in old_event.content
527553 logger.error("Unable to send updated alias events in new room: %s", e)
528554
529555 async def create_room(
530 self, requester, config, ratelimit=True, creator_join_profile=None
556 self,
557 requester: Requester,
558 config: JsonDict,
559 ratelimit: bool = True,
560 creator_join_profile: Optional[JsonDict] = None,
531561 ) -> Tuple[dict, int]:
532562 """ Creates a new room.
533563
534564 Args:
535 requester (synapse.types.Requester):
565 requester:
536566 The user who requested the room creation.
537 config (dict) : A dict of configuration options.
538 ratelimit (bool): set to False to disable the rate limiter
539
540 creator_join_profile (dict|None):
567 config: A dict of configuration options.
568 ratelimit: set to False to disable the rate limiter
569
570 creator_join_profile:
541571 Set to override the displayname and avatar for the creating
542572 user in this room. If unset, displayname and avatar will be
543573 derived from the user's profile. If set, should contain the
600630 Codes.UNSUPPORTED_ROOM_VERSION,
601631 )
602632
633 room_alias = None
603634 if "room_alias_name" in config:
604635 for wchar in string.whitespace:
605636 if wchar in config["room_alias_name"]:
610641
611642 if mapping:
612643 raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
613 else:
614 room_alias = None
615
644
645 invite_3pid_list = config.get("invite_3pid", [])
616646 invite_list = config.get("invite", [])
617647 for i in invite_list:
618648 try:
620650 parse_and_validate_server_name(uid.domain)
621651 except Exception:
622652 raise SynapseError(400, "Invalid user_id: %s" % (i,))
653
654 if (invite_list or invite_3pid_list) and requester.shadow_banned:
655 # We randomly sleep a bit just to annoy the requester.
656 await self.clock.sleep(random.randint(1, 10))
657
658 # Allow the request to go through, but remove any associated invites.
659 invite_3pid_list = []
660 invite_list = []
623661
624662 await self.event_creation_handler.assert_accepted_privacy_policy(requester)
625663
634672 "Not a valid power_level_content_override: 'users' did not contain %s"
635673 % (user_id,),
636674 )
637
638 invite_3pid_list = config.get("invite_3pid", [])
639675
640676 visibility = config.get("visibility", None)
641677 is_public = visibility == "public"
731767 if is_direct:
732768 content["is_direct"] = is_direct
733769
770 # Note that update_membership with an action of "invite" can raise a
771 # ShadowBanError, but this was handled above by emptying invite_list.
734772 _, last_stream_id = await self.room_member_handler.update_membership(
735773 requester,
736774 UserID.from_string(invitee),
745783 id_access_token = invite_3pid.get("id_access_token") # optional
746784 address = invite_3pid["address"]
747785 medium = invite_3pid["medium"]
786 # Note that do_3pid_invite can raise a ShadowBanError, but this was
787 # handled above by emptying invite_3pid_list.
748788 last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
749789 room_id,
750790 requester.user,
770810
771811 async def _send_events_for_new_room(
772812 self,
773 creator, # A Requester object.
774 room_id,
775 preset_config,
776 invite_list,
777 initial_state,
778 creation_content,
779 room_alias=None,
780 power_level_content_override=None, # Doesn't apply when initial state has power level state event content
781 creator_join_profile=None,
813 creator: Requester,
814 room_id: str,
815 preset_config: str,
816 invite_list: List[str],
817 initial_state: MutableStateMap,
818 creation_content: JsonDict,
819 room_alias: Optional[RoomAlias] = None,
820 power_level_content_override: Optional[JsonDict] = None,
821 creator_join_profile: Optional[JsonDict] = None,
782822 ) -> int:
783823 """Sends the initial events into a new room.
824
825 `power_level_content_override` doesn't apply when initial state has
826 power level state event content.
784827
785828 Returns:
786829 The stream_id of the last event persisted.
787830 """
788831
789 def create(etype, content, **kwargs):
832 creator_id = creator.user.to_string()
833
834 event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
835
836 def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
790837 e = {"type": etype, "content": content}
791838
792839 e.update(event_keys)
794841
795842 return e
796843
797 async def send(etype, content, **kwargs) -> int:
844 async def send(etype: str, content: JsonDict, **kwargs) -> int:
798845 event = create(etype, content, **kwargs)
799846 logger.debug("Sending %s in new room", etype)
847 # Allow these events to be sent even if the user is shadow-banned to
848 # allow the room creation to complete.
800849 (
801850 _,
802851 last_stream_id,
803852 ) = await self.event_creation_handler.create_and_send_nonmember_event(
804 creator, event, ratelimit=False
853 creator, event, ratelimit=False, ignore_shadow_ban=True,
805854 )
806855 return last_stream_id
807856
808857 config = self._presets_dict[preset_config]
809
810 creator_id = creator.user.to_string()
811
812 event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
813858
814859 creation_content.update({"creator": creator_id})
815860 await send(etype=EventTypes.Create, content=creation_content)
851896 "kick": 50,
852897 "redact": 50,
853898 "invite": 50,
854 }
899 } # type: JsonDict
855900
856901 if config["original_invitees_have_ops"]:
857902 for invitee in invite_list:
905950 return last_sent_stream_id
906951
907952 async def _generate_room_id(
908 self, creator_id: str, is_public: str, room_version: RoomVersion,
953 self, creator_id: str, is_public: bool, room_version: RoomVersion,
909954 ):
910955 # autogen room IDs and try to create it. We may clash, so just
911956 # try a few times till one goes through, giving up eventually.
928973 raise StoreError(500, "Couldn't generate a room ID.")
929974
930975
931 class RoomContextHandler(object):
932 def __init__(self, hs):
976 class RoomContextHandler:
977 def __init__(self, hs: "HomeServer"):
933978 self.hs = hs
934979 self.store = hs.get_datastore()
935980 self.storage = hs.get_storage()
936981 self.state_store = self.storage.state
937982
938 async def get_event_context(self, user, room_id, event_id, limit, event_filter):
983 async def get_event_context(
984 self,
985 user: UserID,
986 room_id: str,
987 event_id: str,
988 limit: int,
989 event_filter: Optional[Filter],
990 ) -> Optional[JsonDict]:
939991 """Retrieves events, pagination tokens and state around a given event
940992 in a room.
941993
942994 Args:
943 user (UserID)
944 room_id (str)
945 event_id (str)
946 limit (int): The maximum number of events to return in total
995 user
996 room_id
997 event_id
998 limit: The maximum number of events to return in total
947999 (excluding state).
948 event_filter (Filter|None): the filter to apply to the events returned
1000 event_filter: the filter to apply to the events returned
9491001 (excluding the target event_id)
9501002
9511003 Returns:
10311083 return results
10321084
10331085
1034 class RoomEventSource(object):
1035 def __init__(self, hs):
1086 class RoomEventSource:
1087 def __init__(self, hs: "HomeServer"):
10361088 self.store = hs.get_datastore()
10371089
10381090 async def get_new_events(
1039 self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
1040 ):
1091 self,
1092 user: UserID,
1093 from_key: str,
1094 limit: int,
1095 room_ids: List[str],
1096 is_guest: bool,
1097 explicit_room_id: Optional[str] = None,
1098 ) -> Tuple[List[EventBase], str]:
10411099 # We just ignore the key for now.
10421100
10431101 to_key = self.get_current_key()
10871145 return self.store.get_room_events_max_id(room_id)
10881146
10891147
1090 class RoomShutdownHandler(object):
1148 class RoomShutdownHandler:
10911149
10921150 DEFAULT_MESSAGE = (
10931151 "Sharing illegal content on this server is not permitted and rooms in"
10951153 )
10961154 DEFAULT_ROOM_NAME = "Content Violation Notification"
10971155
1098 def __init__(self, hs):
1156 def __init__(self, hs: "HomeServer"):
10991157 self.hs = hs
11001158 self.room_member_handler = hs.get_room_member_handler()
11011159 self._room_creation_handler = hs.get_room_creation_handler()
12711329 ratelimit=False,
12721330 )
12731331
1274 aliases_for_room = await maybe_awaitable(
1275 self.store.get_aliases_for_room(room_id)
1276 )
1332 aliases_for_room = await self.store.get_aliases_for_room(room_id)
12771333
12781334 await self.store.update_aliases_for_room(
12791335 room_id, new_room_id, requester_user_id
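
Room creation handles shadow-banned users slightly differently from other writes: the room itself is still created (the initial state events are sent with ignore_shadow_ban=True so creation can complete), but any invite/invite_3pid entries in the request are quietly emptied after the usual random sleep. A condensed standalone sketch of that branch:

import asyncio
import random
from typing import List, Tuple

async def strip_invites_if_shadow_banned(
    shadow_banned: bool, invite_list: List[str], invite_3pid_list: List[dict]
) -> Tuple[List[str], List[dict]]:
    # Allow the room-creation request to go through, but drop the invites so
    # a shadow-banned user cannot reach other accounts.
    if (invite_list or invite_3pid_list) and shadow_banned:
        await asyncio.sleep(random.randint(1, 10))
        return [], []
    return invite_list, invite_3pid_list
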
1414
1515 import abc
1616 import logging
17 import random
1718 from http import HTTPStatus
18 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
19 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
1920
2021 from unpaddedbase64 import encode_base64
2122
2223 from synapse import types
2324 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
24 from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
25 from synapse.api.errors import (
26 AuthError,
27 Codes,
28 LimitExceededError,
29 ShadowBanError,
30 SynapseError,
31 )
2532 from synapse.api.ratelimiting import Ratelimiter
2633 from synapse.api.room_versions import EventFormatVersions
2734 from synapse.crypto.event_signing import compute_event_reference_hash
3037 from synapse.events.snapshot import EventContext
3138 from synapse.events.validator import EventValidator
3239 from synapse.storage.roommember import RoomsForUser
33 from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
40 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
3441 from synapse.util.async_helpers import Linearizer
3542 from synapse.util.distributor import user_joined_room, user_left_room
3643
4350 logger = logging.getLogger(__name__)
4451
4552
46 class RoomMemberHandler(object):
53 class RoomMemberHandler:
4754 # TODO(paul): This handler currently contains a messy conflation of
4855 # low-level API that works on UserID objects and so on, and REST-level
4956 # API that takes ID strings and returns pagination chunks. These concerns
168175 target: UserID,
169176 room_id: str,
170177 membership: str,
171 prev_event_ids: Collection[str],
178 prev_event_ids: List[str],
172179 txn_id: Optional[str] = None,
173180 ratelimit: bool = True,
174181 content: Optional[dict] = None,
300307 content: Optional[dict] = None,
301308 require_consent: bool = True,
302309 ) -> Tuple[str, int]:
310 """Update a user's membership in a room.
311
312 Args:
313 requester: The user who is performing the update.
314 target: The user whose membership is being updated.
315 room_id: The ID of the room in which the membership is being updated.
316 action: The membership change, see synapse.api.constants.Membership.
317 txn_id: The transaction ID, if given.
318 remote_room_hosts: Remote servers to send the update to.
319 third_party_signed: Information from a 3PID invite.
320 ratelimit: Whether to rate limit the request.
321 content: The content of the created event.
322 require_consent: Whether consent is required.
323
324 Returns:
325 A tuple of the new event ID and stream ID.
326
327 Raises:
328 ShadowBanError if a shadow-banned requester attempts to send an invite.
329 """
330 if action == Membership.INVITE and requester.shadow_banned:
331 # We randomly sleep a bit just to annoy the requester.
332 await self.clock.sleep(random.randint(1, 10))
333 raise ShadowBanError()
334
303335 key = (room_id,)
304336
305337 with (await self.member_linearizer.queue(key)):
339371 # later on.
340372 content = dict(content)
341373
342 if not self.allow_per_room_profiles:
374 if not self.allow_per_room_profiles or requester.shadow_banned:
343375 # Strip profile data, knowing that new profile data will be added to the
344376 # event's content in event_creation_handler.create_event() using the target's
345377 # global profile.
709741 if prev_member_event.membership == Membership.JOIN:
710742 await self._user_left_room(target_user, room_id)
711743
712 async def _can_guest_join(
713 self, current_state_ids: Dict[Tuple[str, str], str]
714 ) -> bool:
744 async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
715745 """
716746 Returns whether a guest can join a room based on its current state.
717747 """
721751
722752 guest_access = await self.store.get_event(guest_access_id)
723753
724 return (
754 return bool(
725755 guest_access
726756 and guest_access.content
727757 and "guest_access" in guest_access.content
778808 txn_id: Optional[str],
779809 id_access_token: Optional[str] = None,
780810 ) -> int:
811 """Invite a 3PID to a room.
812
813 Args:
814 room_id: The room to invite the 3PID to.
815 inviter: The user sending the invite.
816 medium: The 3PID's medium.
817 address: The 3PID's address.
818 id_server: The identity server to use.
819 requester: The user making the request.
820 txn_id: The transaction ID this is part of, or None if this is not
821 part of a transaction.
822 id_access_token: The optional identity server access token.
823
824 Returns:
825 The new stream ID.
826
827 Raises:
828 ShadowBanError if the requester has been shadow-banned.
829 """
781830 if self.config.block_non_admin_invites:
782831 is_requester_admin = await self.auth.is_server_admin(requester.user)
783832 if not is_requester_admin:
785834 403, "Invites have been disabled on this server", Codes.FORBIDDEN
786835 )
787836
837 if requester.shadow_banned:
838 # We randomly sleep a bit just to annoy the requester.
839 await self.clock.sleep(random.randint(1, 10))
840 raise ShadowBanError()
841
788842 # We need to rate limit *before* we send out any 3PID invites, so we
789843 # can't just rely on the standard ratelimiting of events.
790844 await self.base_handler.ratelimit(requester)
809863 )
810864
811865 if invitee:
866 # Note that update_membership with an action of "invite" can raise
867 # a ShadowBanError, but this was done above already.
812868 _, stream_id = await self.update_membership(
813869 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
814870 )
914970 )
915971 return stream_id
916972
917 async def _is_host_in_room(
918 self, current_state_ids: Dict[Tuple[str, str], str]
919 ) -> bool:
973 async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
920974 # Have we just created the room, and is this about to be the very
921975 # first member event?
922976 create_event_id = current_state_ids.get(("m.room.create", ""))
10471101 return event_id, stream_id
10481102
10491103 # The room is too large. Leave.
1050 requester = types.create_requester(user, None, False, None)
1104 requester = types.create_requester(user, None, False, False, None)
10511105 await self.update_membership(
10521106 requester=requester, target=user, room_id=room_id, action="leave"
10531107 )
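
Membership updates take the third approach seen in this changeset (the typing notifications below use it too): update_membership and do_3pid_invite raise ShadowBanError after the random sleep, leaving it to the calling layer to swallow the error and fake a success. A sketch of what such a caller might do; generate_fake_event_id is a hypothetical helper for illustration, not necessarily Synapse's real one:

import random
import string

class ShadowBanError(Exception):
    """Stand-in for synapse.api.errors.ShadowBanError."""

def generate_fake_event_id() -> str:
    # Hypothetical helper: fabricate something shaped like a real event ID.
    return "$" + "".join(random.choices(string.ascii_letters + string.digits, k=43))

async def on_invite_request(handler, requester, target, room_id) -> dict:
    try:
        event_id, _ = await handler.update_membership(
            requester, target, room_id, action="invite"
        )
    except ShadowBanError:
        # Pretend the invite succeeded so the ban stays invisible to the user.
        event_id = generate_fake_event_id()
    return {"event_id": event_id}
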
5353
5454 class SamlHandler:
5555 def __init__(self, hs: "synapse.server.HomeServer"):
56 self.hs = hs
5657 self._saml_client = Saml2Client(hs.config.saml2_sp_config)
5758 self._auth = hs.get_auth()
5859 self._auth_handler = hs.get_auth_handler()
132133 # the dict.
133134 self.expire_sessions()
134135
136 # Pull out the user-agent and IP from the request.
137 user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
138 0
139 ].decode("ascii", "surrogateescape")
140 ip_address = self.hs.get_ip_from_request(request)
141
135142 user_id, current_session = await self._map_saml_response_to_user(
136 resp_bytes, relay_state
143 resp_bytes, relay_state, user_agent, ip_address
137144 )
138145
139146 # Complete the interactive auth session or the login.
146153 await self._auth_handler.complete_sso_login(user_id, request, relay_state)
147154
148155 async def _map_saml_response_to_user(
149 self, resp_bytes: str, client_redirect_url: str
156 self,
157 resp_bytes: str,
158 client_redirect_url: str,
159 user_agent: str,
160 ip_address: str,
150161 ) -> Tuple[str, Optional[Saml2SessionData]]:
151162 """
152163 Given a SAML response, retrieve the cached session and user for it.
154165 Args:
155166 resp_bytes: The SAML response.
156167 client_redirect_url: The redirect URL passed in by the client.
168 user_agent: The user agent of the client making the request.
169 ip_address: The IP address of the client making the request.
157170
158171 Returns:
159172 Tuple of the user ID and SAML session associated with this response.
290303 localpart=localpart,
291304 default_display_name=displayname,
292305 bind_emails=emails,
306 user_agent_ips=[(user_agent, ip_address)],
293307 )
294308
295309 await self._datastore.record_user_external_id(
345359
346360
347361 @attr.s
348 class SamlConfig(object):
362 class SamlConfig:
349363 mxid_source_attribute = attr.ib()
350364 mxid_mapper = attr.ib()
351365
352366
353 class DefaultSamlMappingProvider(object):
367 class DefaultSamlMappingProvider:
354368 __version__ = "0.0.1"
355369
356370 def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
1717 logger = logging.getLogger(__name__)
1818
1919
20 class StateDeltasHandler(object):
20 class StateDeltasHandler:
2121 def __init__(self, hs):
2222 self.store = hs.get_datastore()
2323
1515
1616 import itertools
1717 import logging
18 from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
18 from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
1919
2020 import attr
2121 from prometheus_client import Counter
3030 from synapse.types import (
3131 Collection,
3232 JsonDict,
33 MutableStateMap,
3334 RoomStreamToken,
3435 StateMap,
3536 StreamToken,
4243 from synapse.util.metrics import Measure, measure_func
4344 from synapse.visibility import filter_events_for_client
4445
46 if TYPE_CHECKING:
47 from synapse.server import HomeServer
48
4549 logger = logging.getLogger(__name__)
4650
4751 # Debug logger for https://github.com/matrix-org/synapse/issues/4422
9397 __bool__ = __nonzero__ # python3
9498
9599
96 @attr.s(slots=True, frozen=True)
100 # We can't freeze this class, because we need to update it after it's instantiated in
101 # order to set its unread count. That count is only calculated for a room once we know
102 # the room has updates, which we check after the instance has been created.
103 # This is acceptable because the notification counts are likewise updated after the
104 # instance is created.
105 @attr.s(slots=True)
97106 class JoinedSyncResult:
98107 room_id = attr.ib(type=str)
99108 timeline = attr.ib(type=TimelineBatch)
102111 account_data = attr.ib(type=List[JsonDict])
103112 unread_notifications = attr.ib(type=JsonDict)
104113 summary = attr.ib(type=Optional[JsonDict])
114 unread_count = attr.ib(type=int)
105115
106116 def __nonzero__(self) -> bool:
107117 """Make the result appear empty if there are no updates. This is used
235245 __bool__ = __nonzero__ # python3
236246
237247
238 class SyncHandler(object):
239 def __init__(self, hs):
248 class SyncHandler:
249 def __init__(self, hs: "HomeServer"):
240250 self.hs_config = hs.config
241251 self.store = hs.get_datastore()
242252 self.notifier = hs.get_notifier()
587597 room_id: str,
588598 sync_config: SyncConfig,
589599 batch: TimelineBatch,
590 state: StateMap[EventBase],
600 state: MutableStateMap[EventBase],
591601 now_token: StreamToken,
592602 ) -> Optional[JsonDict]:
593603 """ Works out a room summary block for this room, summarising the number
709719 ]
710720
711721 missing_hero_state = await self.store.get_events(missing_hero_event_ids)
712 missing_hero_state = missing_hero_state.values()
713
714 for s in missing_hero_state:
722
723 for s in missing_hero_state.values():
715724 cache.set(s.state_key, s.event_id)
716725 state[(EventTypes.Member, s.state_key)] = s
717726
735744 since_token: Optional[StreamToken],
736745 now_token: StreamToken,
737746 full_state: bool,
738 ) -> StateMap[EventBase]:
747 ) -> MutableStateMap[EventBase]:
739748 """ Works out the difference in state between the start of the timeline
740749 and the previous sync.
741750
929938
930939 async def unread_notifs_for_room_id(
931940 self, room_id: str, sync_config: SyncConfig
932 ) -> Optional[Dict[str, str]]:
941 ) -> Dict[str, int]:
933942 with Measure(self.clock, "unread_notifs_for_room_id"):
934943 last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
935944 user_id=sync_config.user.to_string(),
937946 receipt_type="m.read",
938947 )
939948
940 if last_unread_event_id:
941 notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
942 room_id, sync_config.user.to_string(), last_unread_event_id
943 )
944 return notifs
945
946 # There is no new information in this period, so your notification
947 # count is whatever it was last time.
948 return None
949 notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
950 room_id, sync_config.user.to_string(), last_unread_event_id
951 )
952 return notifs
949953
950954 async def generate_sync_result(
951955 self,
17681772 ignored_users: Set[str],
17691773 room_builder: "RoomSyncResultBuilder",
17701774 ephemeral: List[JsonDict],
1771 tags: Optional[List[JsonDict]],
1775 tags: Optional[Dict[str, Dict[str, Any]]],
17721776 account_data: Dict[str, JsonDict],
17731777 always_include: bool = False,
17741778 ):
18841888 )
18851889
18861890 if room_builder.rtype == "joined":
1887 unread_notifications = {} # type: Dict[str, str]
1891 unread_notifications = {} # type: Dict[str, int]
18881892 room_sync = JoinedSyncResult(
18891893 room_id=room_id,
18901894 timeline=batch,
18931897 account_data=account_data_events,
18941898 unread_notifications=unread_notifications,
18951899 summary=summary,
1900 unread_count=0,
18961901 )
18971902
18981903 if room_sync or always_include:
18991904 notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
19001905
1901 if notifs is not None:
1902 unread_notifications["notification_count"] = notifs["notify_count"]
1903 unread_notifications["highlight_count"] = notifs["highlight_count"]
1906 unread_notifications["notification_count"] = notifs["notify_count"]
1907 unread_notifications["highlight_count"] = notifs["highlight_count"]
1908
1909 room_sync.unread_count = notifs["unread_count"]
19041910
19051911 sync_result_builder.joined.append(room_sync)
19061912
20682074
20692075
20702076 @attr.s
2071 class RoomSyncResultBuilder(object):
2077 class RoomSyncResultBuilder:
20722078 """Stores information needed to create either a `JoinedSyncResult` or
20732079 `ArchivedSyncResult`.
20742080
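
The sync hunks above unfreeze JoinedSyncResult so that unread_count can be filled in after construction, once unread_notifs_for_room_id has run for the room. A minimal illustration of that mutate-after-construction pattern with attrs:

import attr

@attr.s(slots=True)  # deliberately not frozen=True
class JoinedSyncResultSketch:
    room_id = attr.ib(type=str)
    unread_count = attr.ib(type=int)

room_sync = JoinedSyncResultSketch(room_id="!abc:example.org", unread_count=0)
# ... later, once the notification counts for the room have been computed:
room_sync.unread_count = 3
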
1313 # limitations under the License.
1414
1515 import logging
16 import random
1617 from collections import namedtuple
1718 from typing import TYPE_CHECKING, List, Set, Tuple
1819
19 from synapse.api.errors import AuthError, SynapseError
20 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
2021 from synapse.metrics.background_process_metrics import run_as_background_process
2122 from synapse.replication.tcp.streams import TypingStream
2223 from synapse.types import UserID, get_domain_from_id
226227 self._stopped_typing(member)
227228 return
228229
229 async def started_typing(self, target_user, auth_user, room_id, timeout):
230 async def started_typing(self, target_user, requester, room_id, timeout):
230231 target_user_id = target_user.to_string()
231 auth_user_id = auth_user.to_string()
232 auth_user_id = requester.user.to_string()
232233
233234 if not self.is_mine_id(target_user_id):
234235 raise SynapseError(400, "User is not hosted on this homeserver")
235236
236237 if target_user_id != auth_user_id:
237238 raise AuthError(400, "Cannot set another user's typing state")
239
240 if requester.shadow_banned:
241 # We randomly sleep a bit just to annoy the requester.
242 await self.clock.sleep(random.randint(1, 10))
243 raise ShadowBanError()
238244
239245 await self.auth.check_user_in_room(room_id, target_user_id)
240246
255261
256262 self._push_update(member=member, typing=True)
257263
258 async def stopped_typing(self, target_user, auth_user, room_id):
264 async def stopped_typing(self, target_user, requester, room_id):
259265 target_user_id = target_user.to_string()
260 auth_user_id = auth_user.to_string()
266 auth_user_id = requester.user.to_string()
261267
262268 if not self.is_mine_id(target_user_id):
263269 raise SynapseError(400, "User is not hosted on this homeserver")
264270
265271 if target_user_id != auth_user_id:
266272 raise AuthError(400, "Cannot set another user's typing state")
273
274 if requester.shadow_banned:
275 # We randomly sleep a bit just to annoy the requester.
276 await self.clock.sleep(random.randint(1, 10))
277 raise ShadowBanError()
267278
268279 await self.auth.check_user_in_room(room_id, target_user_id)
269280
400411 raise Exception("Typing writer instance got typing info over replication")
401412
402413
403 class TypingNotificationEventSource(object):
414 class TypingNotificationEventSource:
404415 def __init__(self, hs):
405416 self.hs = hs
406417 self.clock = hs.get_clock()
1515 import logging
1616 from typing import Any
1717
18 from canonicaljson import json
19
2018 from twisted.web.client import PartialDownloadError
2119
2220 from synapse.api.constants import LoginType
2321 from synapse.api.errors import Codes, LoginError, SynapseError
2422 from synapse.config.emailconfig import ThreepidBehaviour
23 from synapse.util import json_decoder
2524
2625 logger = logging.getLogger(__name__)
2726
116115 except PartialDownloadError as pde:
117116 # Twisted is silly
118117 data = pde.response
119 resp_body = json.loads(data.decode("utf-8"))
118 resp_body = json_decoder.decode(data.decode("utf-8"))
120119
121120 if "success" in resp_body:
122121 # Note that we do NOT check the hostname here: we explicitly
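
This is one of many hunks in the diff that swap canonicaljson's re-exported json module for a shared json_decoder from synapse.util. Assuming that utility amounts to a module-level JSONDecoder (and matching compact encoder) built once and reused, rather than constructed per call, a sketch of the idea:

import json

# Assumed shape of the synapse.util helpers: one decoder and one compact
# encoder constructed up front instead of on every request.
json_decoder = json.JSONDecoder()
json_encoder = json.JSONEncoder(separators=(",", ":"))

resp_body = json_decoder.decode(b'{"success": true}'.decode("utf-8"))
assert resp_body["success"] is True
assert json_encoder.encode(resp_body) == '{"success":true}'
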
233233 async def _handle_room_publicity_change(
234234 self, room_id, prev_event_id, event_id, typ
235235 ):
236 """Handle a room having potentially changed from/to world_readable/publically
236 """Handle a room having potentially changed from/to world_readable/publicly
237237 joinable.
238238
239239 Args:
387387
388388 prev_name = prev_event.content.get("displayname")
389389 new_name = event.content.get("displayname")
390 # If the new name is an unexpected form, do not update the directory.
391 if not isinstance(new_name, str):
392 new_name = prev_name
390393
391394 prev_avatar = prev_event.content.get("avatar_url")
392395 new_avatar = event.content.get("avatar_url")
396 # If the new avatar is an unexpected form, do not update the directory.
397 if not isinstance(new_avatar, str):
398 new_avatar = prev_avatar
393399
394400 if prev_name != new_name or prev_avatar != new_avatar:
395401 await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
1818 from io import BytesIO
1919
2020 import treq
21 from canonicaljson import encode_canonical_json, json
21 from canonicaljson import encode_canonical_json
2222 from netaddr import IPAddress
2323 from prometheus_client import Counter
2424 from zope.interface import implementer, provider
4646 from synapse.http.proxyagent import ProxyAgent
4747 from synapse.logging.context import make_deferred_yieldable
4848 from synapse.logging.opentracing import set_tag, start_active_span, tags
49 from synapse.util import json_decoder
4950 from synapse.util.async_helpers import timeout_deferred
5051
5152 logger = logging.getLogger(__name__)
8485 return _scheduler
8586
8687
87 class IPBlacklistingResolver(object):
88 class IPBlacklistingResolver:
8889 """
8990 A proxy for reactor.nameResolver which only produces non-blacklisted IP
9091 addresses, preventing DNS rebinding attacks on URL preview.
131132 r.resolutionComplete()
132133
133134 @provider(IResolutionReceiver)
134 class EndpointReceiver(object):
135 class EndpointReceiver:
135136 @staticmethod
136137 def resolutionBegan(resolutionInProgress):
137138 pass
190191 )
191192
192193
193 class SimpleHttpClient(object):
194 class SimpleHttpClient:
194195 """
195196 A simple, no-frills HTTP client with methods that wrap up common ways of
196197 using HTTP in Matrix
242243 )
243244
244245 @implementer(IReactorPluggableNameResolver)
245 class Reactor(object):
246 class Reactor:
246247 def __getattr__(_self, attr):
247248 if attr == "nameResolver":
248249 return nameResolver
390391 body = await make_deferred_yieldable(readBody(response))
391392
392393 if 200 <= response.code < 300:
393 return json.loads(body.decode("utf-8"))
394 return json_decoder.decode(body.decode("utf-8"))
394395 else:
395396 raise HttpResponseException(
396397 response.code, response.phrase.decode("ascii", errors="replace"), body
432433 body = await make_deferred_yieldable(readBody(response))
433434
434435 if 200 <= response.code < 300:
435 return json.loads(body.decode("utf-8"))
436 return json_decoder.decode(body.decode("utf-8"))
436437 else:
437438 raise HttpResponseException(
438439 response.code, response.phrase.decode("ascii", errors="replace"), body
462463 actual_headers.update(headers)
463464
464465 body = await self.get_raw(uri, args, headers=headers)
465 return json.loads(body.decode("utf-8"))
466 return json_decoder.decode(body.decode("utf-8"))
466467
467468 async def put_json(self, uri, json_body, args={}, headers=None):
468469 """ Puts some json to the given URI.
505506 body = await make_deferred_yieldable(readBody(response))
506507
507508 if 200 <= response.code < 300:
508 return json.loads(body.decode("utf-8"))
509 return json_decoder.decode(body.decode("utf-8"))
509510 else:
510511 raise HttpResponseException(
511512 response.code, response.phrase.decode("ascii", errors="replace"), body
3030
3131
3232 @implementer(IStreamClientEndpoint)
33 class HTTPConnectProxyEndpoint(object):
33 class HTTPConnectProxyEndpoint:
3434 """An Endpoint implementation which will send a CONNECT request to an http proxy
3535
3636 Wraps an existing HostnameEndpoint for the proxy.
3535
3636
3737 @implementer(IAgent)
38 class MatrixFederationAgent(object):
38 class MatrixFederationAgent:
3939 """An Agent-like thing which provides a `request` method which correctly
4040 handles resolving matrix server names when using matrix://. Handles standard
4141 https URIs as normal.
133133 and not _is_ip_literal(parsed_uri.hostname)
134134 and not parsed_uri.port
135135 ):
136 well_known_result = yield self._well_known_resolver.get_well_known(
137 parsed_uri.hostname
136 well_known_result = yield defer.ensureDeferred(
137 self._well_known_resolver.get_well_known(parsed_uri.hostname)
138138 )
139139 delegated_server = well_known_result.delegated_server
140140
174174
175175
176176 @implementer(IAgentEndpointFactory)
177 class MatrixHostnameEndpointFactory(object):
177 class MatrixHostnameEndpointFactory:
178178 """Factory for MatrixHostnameEndpoint for passing to an Agent.
179179 """
180180
197197
198198
199199 @implementer(IStreamClientEndpoint)
200 class MatrixHostnameEndpoint(object):
200 class MatrixHostnameEndpoint:
201201 """An endpoint that resolves matrix:// URLs using Matrix server name
202202 resolution (i.e. via SRV). Does not check for well-known delegation.
203203
3232
3333
3434 @attr.s(slots=True, frozen=True)
35 class Server(object):
35 class Server:
3636 """
3737 Our record of an individual server which can be tried to reach a destination.
3838
9595 return results
9696
9797
98 class SrvResolver(object):
98 class SrvResolver:
9999 """Interface to the dns client to do SRV lookups, with result caching.
100100
101101 The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import json
1615 import logging
1716 import random
1817 import time
18 from typing import Callable, Dict, Optional, Tuple
1919
2020 import attr
2121
2323 from twisted.web.client import RedirectAgent, readBody
2424 from twisted.web.http import stringToDatetime
2525 from twisted.web.http_headers import Headers
26 from twisted.web.iweb import IResponse
2627
2728 from synapse.logging.context import make_deferred_yieldable
28 from synapse.util import Clock
29 from synapse.util import Clock, json_decoder
2930 from synapse.util.caches.ttlcache import TTLCache
3031 from synapse.util.metrics import Measure
3132
6970
7071
7172 @attr.s(slots=True, frozen=True)
72 class WellKnownLookupResult(object):
73 class WellKnownLookupResult:
7374 delegated_server = attr.ib()
7475
7576
76 class WellKnownResolver(object):
77 class WellKnownResolver:
7778 """Handles well-known lookups for matrix servers.
7879 """
7980
99100 self._well_known_agent = RedirectAgent(agent)
100101 self.user_agent = user_agent
101102
102 @defer.inlineCallbacks
103 def get_well_known(self, server_name):
103 async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
104104 """Attempt to fetch and parse a .well-known file for the given server
105105
106106 Args:
107 server_name (bytes): name of the server, from the requested url
107 server_name: name of the server, from the requested url
108108
109109 Returns:
110 Deferred[WellKnownLookupResult]: The result of the lookup
110 The result of the lookup
111111 """
112112 try:
113113 prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
124124 # requests for the same server in parallel?
125125 try:
126126 with Measure(self._clock, "get_well_known"):
127 result, cache_period = yield self._fetch_well_known(server_name)
127 result, cache_period = await self._fetch_well_known(
128 server_name
129 ) # type: Tuple[Optional[bytes], float]
128130
129131 except _FetchWellKnownFailure as e:
130132 if prev_result and e.temporary:
153155
154156 return WellKnownLookupResult(delegated_server=result)
155157
156 @defer.inlineCallbacks
157 def _fetch_well_known(self, server_name):
158 async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
158159 """Actually fetch and parse a .well-known, without checking the cache
159160
160161 Args:
161 server_name (bytes): name of the server, from the requested url
162 server_name: name of the server, from the requested url
162163
163164 Raises:
164165 _FetchWellKnownFailure if we fail to lookup a result
165166
166167 Returns:
167 Deferred[Tuple[bytes,int]]: The lookup result and cache period.
168 The lookup result and cache period.
168169 """
169170
170171 had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
172173 # We do this in two steps to differentiate between possibly transient
173174 # errors (e.g. can't connect to host, 503 response) and more permanent
174175 # errors (such as getting a 404 response).
175 response, body = yield self._make_well_known_request(
176 response, body = await self._make_well_known_request(
176177 server_name, retry=had_valid_well_known
177178 )
178179
180181 if response.code != 200:
181182 raise Exception("Non-200 response %s" % (response.code,))
182183
183 parsed_body = json.loads(body.decode("utf-8"))
184 parsed_body = json_decoder.decode(body.decode("utf-8"))
184185 logger.info("Response from .well-known: %s", parsed_body)
185186
186187 result = parsed_body["m.server"].encode("ascii")
215216
216217 return result, cache_period
217218
218 @defer.inlineCallbacks
219 def _make_well_known_request(self, server_name, retry):
219 async def _make_well_known_request(
220 self, server_name: bytes, retry: bool
221 ) -> Tuple[IResponse, bytes]:
220222 """Make the well known request.
221223
222224 This will retry the request if requested and it fails (if it is unable
223225 to connect, or receives a 5xx error).
224226
225227 Args:
226 server_name (bytes)
227 retry (bool): Whether to retry the request if it fails.
228 server_name: name of the server, from the requested url
229 retry: Whether to retry the request if it fails.
228230
229231 Returns:
230 Deferred[tuple[IResponse, bytes]] Returns the response object and
231 body. Response may be a non-200 response.
232 Returns the response object and body. Response may be a non-200 response.
232233 """
233234 uri = b"https://%s/.well-known/matrix/server" % (server_name,)
234235 uri_str = uri.decode("ascii")
243244
244245 logger.info("Fetching %s", uri_str)
245246 try:
246 response = yield make_deferred_yieldable(
247 response = await make_deferred_yieldable(
247248 self._well_known_agent.request(
248249 b"GET", uri, headers=Headers(headers)
249250 )
250251 )
251 body = yield make_deferred_yieldable(readBody(response))
252 body = await make_deferred_yieldable(readBody(response))
252253
253254 if 500 <= response.code < 600:
254255 raise Exception("Non-200 response %s" % (response.code,))
265266 logger.info("Error fetching %s: %s. Retrying", uri_str, e)
266267
267268 # Sleep briefly in the hopes that they come back up
268 yield self._clock.sleep(0.5)
269
270
271 def _cache_period_from_headers(headers, time_now=time.time):
269 await self._clock.sleep(0.5)
270
271
272 def _cache_period_from_headers(
273 headers: Headers, time_now: Callable[[], float] = time.time
274 ) -> Optional[float]:
272275 cache_controls = _parse_cache_control(headers)
273276
274277 if b"no-store" in cache_controls:
275278 return 0
276279
277280 if b"max-age" in cache_controls:
278 try:
279 max_age = int(cache_controls[b"max-age"])
280 return max_age
281 except ValueError:
282 pass
281 max_age = cache_controls[b"max-age"]
282 if max_age:
283 try:
284 return int(max_age)
285 except ValueError:
286 pass
283287
284288 expires = headers.getRawHeaders(b"expires")
285289 if expires is not None:
295299 return None
296300
297301
298 def _parse_cache_control(headers):
302 def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
299303 cache_controls = {}
300304 for hdr in headers.getRawHeaders(b"cache-control", []):
301305 for directive in hdr.split(b","):
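
The reworked max-age handling above guards against a bare max-age directive with no value: _parse_cache_control stores such directives as None, and the old int(...) call on None would have raised a TypeError that the surrounding except ValueError did not catch. A standalone re-implementation of the parsing rules (not the Synapse source) to make the behaviour concrete:

from typing import Dict, Optional

def parse_cache_control(raw: bytes) -> Dict[bytes, Optional[bytes]]:
    # Directives without an "=value" part map to None, e.g. b"no-store".
    cache_controls = {}  # type: Dict[bytes, Optional[bytes]]
    for directive in raw.split(b","):
        splits = [x.strip() for x in directive.split(b"=", 1)]
        cache_controls[splits[0].lower()] = splits[1] if len(splits) > 1 else None
    return cache_controls

def cache_period(raw: bytes) -> Optional[float]:
    controls = parse_cache_control(raw)
    if b"no-store" in controls:
        return 0
    max_age = controls.get(b"max-age")
    if max_age:  # skips both a missing directive and a valueless max-age
        try:
            return int(max_age)
        except ValueError:
            pass
    return None

assert cache_period(b"no-store") == 0
assert cache_period(b"max-age=3600") == 3600
assert cache_period(b"max-age") is None  # the case the new guard handles
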
5353 start_active_span,
5454 tags,
5555 )
56 from synapse.util import json_decoder
5657 from synapse.util.async_helpers import timeout_deferred
5758 from synapse.util.metrics import Measure
5859
7576
7677
7778 @attr.s(frozen=True)
78 class MatrixFederationRequest(object):
79 class MatrixFederationRequest:
7980 method = attr.ib()
8081 """HTTP method
8182 :type: str
163164 try:
164165 check_content_type_is_json(response.headers)
165166
166 d = treq.json_content(response)
167 # Use the custom JSON decoder (partially re-implements treq.json_content).
168 d = treq.text_content(response, encoding="utf-8")
169 d.addCallback(json_decoder.decode)
167170 d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
168171
169172 body = await make_deferred_yieldable(d)
202205 return body
203206
204207
205 class MatrixFederationHttpClient(object):
208 class MatrixFederationHttpClient:
206209 """HTTP client used to talk to other homeservers over the federation
207210 protocol. Send client certificates and signs requests.
208211
225228 )
226229
227230 @implementer(IReactorPluggableNameResolver)
228 class Reactor(object):
231 class Reactor:
229232 def __getattr__(_self, attr):
230233 if attr == "nameResolver":
231234 return nameResolver
144144 )
145145
146146
147 class RequestMetrics(object):
147 class RequestMetrics:
148148 def start(self, time_sec, name, method):
149149 self.start = time_sec
150150 self.start_context = current_context()
2121 import urllib
2222 from http import HTTPStatus
2323 from io import BytesIO
24 from typing import Any, Callable, Dict, Tuple, Union
24 from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
2525
2626 import jinja2
27 from canonicaljson import encode_canonical_json, encode_pretty_printed_json
28
29 from twisted.internet import defer
27 from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
28 from zope.interface import implementer
29
30 from twisted.internet import defer, interfaces
3031 from twisted.python import failure
3132 from twisted.web import resource
3233 from twisted.web.server import NOT_DONE_YET, Request
172173 return preserve_fn(wrapped_async_request_handler)
173174
174175
175 class HttpServer(object):
176 class HttpServer:
176177 """ Interface for registering callbacks on a HTTP server
177178 """
178179
496497
497498 class RootOptionsRedirectResource(OptionsResource, RootRedirect):
498499 pass
500
501
502 @implementer(interfaces.IPushProducer)
503 class _ByteProducer:
504 """
505 Iteratively write bytes to the request.
506 """
507
508 # The minimum number of bytes for each chunk. Note that the last chunk will
509 # usually be smaller than this.
510 min_chunk_size = 1024
511
512 def __init__(
513 self, request: Request, iterator: Iterator[bytes],
514 ):
515 self._request = request
516 self._iterator = iterator
517 self._paused = False
518
519 # Register the producer and start producing data.
520 self._request.registerProducer(self, True)
521 self.resumeProducing()
522
523 def _send_data(self, data: List[bytes]) -> None:
524 """
525 Send a list of bytes as a chunk of a response.
526 """
527 if not data:
528 return
529 self._request.write(b"".join(data))
530
531 def pauseProducing(self) -> None:
532 self._paused = True
533
534 def resumeProducing(self) -> None:
535 # Production may have been stopped in the meantime (note that this
536 # method can be re-entered after calling write).
537 if not self._request:
538 return
539
540 self._paused = False
541
542 # Write until there's backpressure telling us to stop.
543 while not self._paused:
544 # Get the next chunk and write it to the request.
545 #
546 # The output of the JSON encoder is buffered and coalesced until
547 # min_chunk_size is reached. This is because JSON encoders produce
548 # very small output per iteration and the Request object converts
549 # each call to write() to a separate chunk. Without this there would
550 # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
551 #
552 # Note that buffer stores a list of bytes (instead of appending to
553 # bytes) to hopefully avoid many allocations.
554 buffer = []
555 buffered_bytes = 0
556 while buffered_bytes < self.min_chunk_size:
557 try:
558 data = next(self._iterator)
559 buffer.append(data)
560 buffered_bytes += len(data)
561 except StopIteration:
562 # The entire JSON object has been serialized, write any
563 # remaining data, finalize the producer and the request, and
564 # clean-up any references.
565 self._send_data(buffer)
566 self._request.unregisterProducer()
567 self._request.finish()
568 self.stopProducing()
569 return
570
571 self._send_data(buffer)
572
573 def stopProducing(self) -> None:
574 # Clear a circular reference.
575 self._request = None
576
577
578 def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
579 """
580 Encode an object into JSON. Returns an iterator of bytes.
581 """
582 for chunk in json_encoder.iterencode(json_object):
583 yield chunk.encode("utf-8")
499584
500585
501586 def respond_with_json(
532617 return None
533618
534619 if pretty_print:
535 json_bytes = encode_pretty_printed_json(json_object) + b"\n"
620 encoder = iterencode_pretty_printed_json
536621 else:
537622 if canonical_json or synapse.events.USE_FROZEN_DICTS:
538 # canonicaljson already encodes to bytes
539 json_bytes = encode_canonical_json(json_object)
623 encoder = iterencode_canonical_json
540624 else:
541 json_bytes = json_encoder.encode(json_object).encode("utf-8")
542
543 return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
625 encoder = _encode_json_bytes
626
627 request.setResponseCode(code)
628 request.setHeader(b"Content-Type", b"application/json")
629 request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
630
631 if send_cors:
632 set_cors_headers(request)
633
634 _ByteProducer(request, encoder(json_object))
635 return NOT_DONE_YET
544636
545637
546638 def respond_with_json_bytes(
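
The _ByteProducer above exists because JSONEncoder.iterencode yields many very small fragments, and Twisted's Request turns every write() into its own chunked-transfer frame; buffering up to a minimum chunk size avoids that framing blow-up while still streaming the response. The buffering idea in isolation, without Twisted (min_chunk_size mirrors the constant above):

import json

encoder = json.JSONEncoder()

def chunked_json(obj, min_chunk_size=1024):
    """Coalesce iterencode's tiny fragments into chunks of >= min_chunk_size bytes."""
    buffer = []  # type: list
    buffered_bytes = 0
    for fragment in encoder.iterencode(obj):
        data = fragment.encode("utf-8")
        buffer.append(data)
        buffered_bytes += len(data)
        if buffered_bytes >= min_chunk_size:
            yield b"".join(buffer)
            buffer, buffered_bytes = [], 0
    if buffer:
        # Flush whatever is left; the last chunk is usually smaller.
        yield b"".join(buffer)

obj = {"rooms": [{"event_id": "$%d" % i} for i in range(500)]}
fragments = sum(1 for _ in encoder.iterencode(obj))
chunks = sum(1 for _ in chunked_json(obj))
print(fragments, "fragments coalesced into", chunks, "chunks")
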
1616
1717 import logging
1818
19 from canonicaljson import json
20
2119 from synapse.api.errors import Codes, SynapseError
20 from synapse.util import json_decoder
2221
2322 logger = logging.getLogger(__name__)
2423
214213 return None
215214
216215 try:
217 content = json.loads(content_bytes.decode("utf-8"))
216 content = json_decoder.decode(content_bytes.decode("utf-8"))
218217 except Exception as e:
219218 logger.warning("Unable to parse JSON: %s", e)
220219 raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
256255 raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
257256
258257
259 class RestServlet(object):
258 class RestServlet:
260259
261260 """ A Synapse REST Servlet.
262261
5454
5555 @attr.s
5656 @implementer(ILogObserver)
57 class LogContextObserver(object):
57 class LogContextObserver:
5858 """
5959 An ILogObserver which adds Synapse-specific log context information.
6060
168168
169169
170170 @attr.s
171 class DrainConfiguration(object):
171 class DrainConfiguration:
172172 name = attr.ib()
173173 type = attr.ib()
174174 location = attr.ib()
176176
177177
178178 @attr.s
179 class NetworkJSONTerseOptions(object):
179 class NetworkJSONTerseOptions:
180180 maximum_buffer = attr.ib(type=int)
181181
182182
151151
152152 @attr.s
153153 @implementer(IPushProducer)
154 class LogProducer(object):
154 class LogProducer:
155155 """
156156 An IPushProducer that writes logs from its buffer to its transport when it
157157 is resumed.
189189
190190 @attr.s
191191 @implementer(ILogObserver)
192 class TerseJSONToTCPLogObserver(object):
192 class TerseJSONToTCPLogObserver:
193193 """
194194 An IObserver that writes JSON logs to a TCP target.
195195
7373 get_thread_id = threading.get_ident
7474
7575
76 class ContextResourceUsage(object):
76 class ContextResourceUsage:
7777 """Object for tracking the resources used by a log context
7878
7979 Attributes:
178178 LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
179179
180180
181 class _Sentinel(object):
181 class _Sentinel:
182182 """Sentinel to represent the root context"""
183183
184184 __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
225225 SENTINEL_CONTEXT = _Sentinel()
226226
227227
228 class LoggingContext(object):
228 class LoggingContext:
229229 """Additional context for log formatting. Contexts are scoped within a
230230 "with" block.
231231
171171 from typing import TYPE_CHECKING, Dict, Optional, Type
172172
173173 import attr
174 from canonicaljson import json
175174
176175 from twisted.internet import defer
177176
178177 from synapse.config import ConfigError
178 from synapse.util import json_decoder, json_encoder
179179
180180 if TYPE_CHECKING:
181181 from synapse.http.site import SynapseRequest
184184 # Helper class
185185
186186
187 class _DummyTagNames(object):
187 class _DummyTagNames:
188188 """Wrapper of opentracing's tags. We need to have them if we
189189 want to reference them without opentracing around. Clearly they
190190 should never actually show up in a trace. `set_tags` overwrites
498498 if opentracing is None:
499499 return _noop_context_manager()
500500
501 carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
501 carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
502 "opentracing", {}
503 )
502504 context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
503505 _references = [
504506 opentracing.child_of(span_context_from_string(x))
689691 opentracing.tracer.inject(
690692 opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
691693 )
692 return json.dumps(carrier)
694 return json_encoder.encode(carrier)
693695
694696
695697 @only_if_tracing
698700 Returns:
699701 The active span context decoded from a string.
700702 """
701 carrier = json.loads(carrier)
703 carrier = json_decoder.decode(carrier)
702704 return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
703705
704706
5050 HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
5151
5252
53 class RegistryProxy(object):
53 class RegistryProxy:
5454 @staticmethod
5555 def collect():
5656 for metric in REGISTRY.collect():
5959
6060
6161 @attr.s(hash=True)
62 class LaterGauge(object):
62 class LaterGauge:
6363
6464 name = attr.ib(type=str)
6565 desc = attr.ib(type=str)
9999 all_gauges[self.name] = self
100100
101101
102 class InFlightGauge(object):
102 class InFlightGauge:
103103 """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
104104 at any given time.
105105
205205
206206
207207 @attr.s(hash=True)
208 class BucketCollector(object):
208 class BucketCollector:
209209 """
210210 Like a Histogram, but allows buckets to be point-in-time instead of
211211 incrementally added to.
268268 #
269269
270270
271 class CPUMetrics(object):
271 class CPUMetrics:
272272 def __init__(self):
273273 ticks_per_sec = 100
274274 try:
328328 )
329329
330330
331 class GCCounts(object):
331 class GCCounts:
332332 def collect(self):
333333 cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
334334 for n, m in enumerate(gc.get_count()):
346346 #
347347
348348
349 class PyPyGCStats(object):
349 class PyPyGCStats:
350350 def collect(self):
351351
352352 # @stats is a pretty-printer object with __str__() returning a nice table,
481481 last_ticked = time.time()
482482
483483
484 class ReactorLastSeenMetric(object):
484 class ReactorLastSeenMetric:
485485 def collect(self):
486486 cm = GaugeMetricFamily(
487487 "python_twisted_reactor_last_seen",
104104 _bg_metrics_lock = threading.Lock()
105105
106106
107 class _Collector(object):
107 class _Collector:
108108 """A custom metrics collector for the background process metrics.
109109
110110 Ensures that all of the metrics are up-to-date with any in-flight processes
139139 REGISTRY.register(_Collector())
140140
141141
142 class _BackgroundProcess(object):
142 class _BackgroundProcess:
143143 def __init__(self, desc, ctx):
144144 self.desc = desc
145145 self._context = ctx
174174 It returns a Deferred which completes when the function completes, but it doesn't
175175 follow the synapse logcontext rules, which makes it appropriate for passing to
176176 clock.looping_call and friends (or for firing-and-forgetting in the middle of a
177 normal synapse inlineCallbacks function).
177 normal synapse async function).
178178
179179 Args:
180180 desc: a description for this background process type
3030 logger = logging.getLogger(__name__)
3131
3232
33 class ModuleApi(object):
33 class ModuleApi:
3434 """A proxy object that gets passed to various plugin modules so they
3535 can register new users etc if necessary.
3636 """
166166 external_id: id on that system
167167 user_id: complete mxid that it is mapped to
168168 """
169 return self._store.record_user_external_id(
170 auth_provider_id, remote_user_id, registered_user_id
169 return defer.ensureDeferred(
170 self._store.record_user_external_id(
171 auth_provider_id, remote_user_id, registered_user_id
172 )
171173 )
172174
173175 def generate_short_term_login_token(
222224 Returns:
223225 Deferred[object]: result of func
224226 """
225 return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
227 return defer.ensureDeferred(
228 self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
229 )
226230
227231 def complete_sso_login(
228232 self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
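
The ModuleApi hunks above wrap the newly-async storage methods in defer.ensureDeferred: third-party modules were documented as receiving Deferreds, so the coroutines that the storage layer now returns have to be bridged back to that interface. The shape of that compatibility shim in isolation:

from twisted.internet import defer

async def _record_user_external_id() -> None:
    ...  # the storage method is now a coroutine function

def record_user_external_id():
    # Modules expect a Deferred, so wrap the coroutine rather than
    # changing the public module API.
    return defer.ensureDeferred(_record_user_external_id())
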
6767 return n
6868
6969
70 class _NotificationListener(object):
70 class _NotificationListener:
7171 """ This represents a single client connection to the events stream.
7272 The events stream handler will have yielded to the deferred, so to
7373 notify the handler it is sufficient to resolve the deferred.
7979 self.deferred = deferred
8080
8181
82 class _NotifierUserStream(object):
82 class _NotifierUserStream:
8383 """This represents a user connected to the event stream.
8484 It tracks the most recent stream token for that user.
8585 At a given point a user may have a number of streams listening for
167167 __bool__ = __nonzero__ # python3
168168
169169
170 class Notifier(object):
170 class Notifier:
171171 """ This class is responsible for notifying any listeners when there are
172172 new events available for it.
173173
2121 logger = logging.getLogger(__name__)
2222
2323
24 class ActionGenerator(object):
24 class ActionGenerator:
2525 def __init__(self, hs):
2626 self.hs = hs
2727 self.clock = hs.get_clock()
1818
1919 from prometheus_client import Counter
2020
21 from synapse.api.constants import EventTypes, Membership
21 from synapse.api.constants import EventTypes, Membership, RelationTypes
2222 from synapse.event_auth import get_user_power_level
23 from synapse.events import EventBase
24 from synapse.events.snapshot import EventContext
2325 from synapse.state import POWER_KEY
2426 from synapse.util.async_helpers import Linearizer
2527 from synapse.util.caches import register_cache
5052 )
5153
5254
53 class BulkPushRuleEvaluator(object):
55 STATE_EVENT_TYPES_TO_MARK_UNREAD = {
56 EventTypes.Topic,
57 EventTypes.Name,
58 EventTypes.RoomAvatar,
59 EventTypes.Tombstone,
60 }
61
62
63 def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
64 # Exclude rejected and soft-failed events.
65 if context.rejected or event.internal_metadata.is_soft_failed():
66 return False
67
68 # Exclude notices.
69 if (
70 not event.is_state()
71 and event.type == EventTypes.Message
72 and event.content.get("msgtype") == "m.notice"
73 ):
74 return False
75
76 # Exclude edits.
77 relates_to = event.content.get("m.relates_to", {})
78 if relates_to.get("rel_type") == RelationTypes.REPLACE:
79 return False
80
81 # Mark events that have a non-empty string body as unread.
82 body = event.content.get("body")
83 if isinstance(body, str) and body:
84 return True
85
86 # Mark some state events as unread.
87 if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
88 return True
89
90 # Mark encrypted events as unread.
91 if not event.is_state() and event.type == EventTypes.Encrypted:
92 return True
93
94 return False
95
96
97 class BulkPushRuleEvaluator:
5498 """Calculates the outcome of push rules for an event for all users in the
5599 room at once.
56100 """
132176 return pl_event.content if pl_event else {}, sender_level
133177
134178 async def action_for_event_by_user(self, event, context) -> None:
135 """Given an event and context, evaluate the push rules and insert the
136 results into the event_push_actions_staging table.
137 """
179 """Given an event and context, evaluate the push rules, check if the message
180 should increment the unread count, and insert the results into the
181 event_push_actions_staging table.
182 """
183 count_as_unread = _should_count_as_unread(event, context)
184
138185 rules_by_user = await self._get_rules_for_event(event, context)
139186 actions_by_user = {}
140187
170217 # that user, as they might not be already joined.
171218 if event.type == EventTypes.Member and event.state_key == uid:
172219 display_name = event.content.get("displayname", None)
220
221 if count_as_unread:
222 # Add an element for the current user if the event needs to be marked as
223 # unread, so that add_push_actions_to_staging iterates over it.
224 # If the event shouldn't be marked as unread but should notify the
225 # current user, it'll be added to the dict later.
226 actions_by_user[uid] = []
173227
174228 for rule in rules:
175229 if "enabled" in rule and not rule["enabled"]:
188242 # Mark in the DB staging area the push actions for users who should be
189243 # notified for this event. (This will then get handled when we persist
190244 # the event)
191 await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
245 await self.store.add_push_actions_to_staging(
246 event.event_id, actions_by_user, count_as_unread,
247 )
192248
193249
194250 def _condition_checker(evaluator, conditions, uid, display_name, cache):
211267 return True
212268
213269
214 class RulesForRoom(object):
270 class RulesForRoom:
215271 """Caches push rules for users in a room.
216272
217273 This efficiently handles users joining/leaving the room by not invalidating
368424 Args:
369425 ret_rules_by_user (dict): Partially filled dict of push rules. Gets
370426 updated with any new rules.
371 member_event_ids (list): List of event ids for membership events that
372 have happened since the last time we filled rules_by_user
427 member_event_ids (dict): Dict of user id to event id for membership events
428 that have happened since the last time we filled rules_by_user
373429 state_group: The state group we are currently computing push rules
374430 for. Used when updating the cache.
375431 """
389445 if logger.isEnabledFor(logging.DEBUG):
390446 logger.debug("Found members %r: %r", self.room_id, members.values())
391447
392 interested_in_user_ids = {
448 user_ids = {
393449 user_id
394450 for user_id, membership in members.values()
395451 if membership == Membership.JOIN
396452 }
397453
398 logger.debug("Joined: %r", interested_in_user_ids)
399
400 if_users_with_pushers = await self.store.get_if_users_have_pushers(
401 interested_in_user_ids, on_invalidate=self.invalidate_all_cb
402 )
403
404 user_ids = {
405 uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
406 }
407
408 logger.debug("With pushers: %r", user_ids)
409
410 users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
411 self.room_id, on_invalidate=self.invalidate_all_cb
412 )
413
414 logger.debug("With receipts: %r", users_with_receipts)
415
416 # any users with pushers must be ours: they have pushers
417 for uid in users_with_receipts:
418 if uid in interested_in_user_ids:
419 user_ids.add(uid)
454 logger.debug("Joined: %r", user_ids)
455
456 # Previously we only considered users with pushers or read receipts in that
457 # room. We can't do this anymore because we use push actions to calculate unread
458 # counts, which don't rely on the user having pushers or having sent a read receipt into
459 # the room. Therefore we just need to filter for local users here.
460 user_ids = list(filter(self.is_mine_id, user_ids))
420461
421462 rules_by_user = await self.store.bulk_get_push_rules(
422463 user_ids, on_invalidate=self.invalidate_all_cb
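The new `_should_count_as_unread` helper above drives the MSC2654-style unread counts: rejected and soft-failed events, notices, and edits never count, while non-empty text bodies, a handful of state events, and encrypted events do. A condensed, self-contained restatement using plain dicts instead of Synapse's `EventBase` (illustrative only; the state and encrypted branches are omitted):

    def counts_as_unread(ev: dict) -> bool:
        content = ev.get("content", {})
        # notices never count
        if ev.get("type") == "m.room.message" and content.get("msgtype") == "m.notice":
            return False
        # edits (m.replace relations) never count
        if content.get("m.relates_to", {}).get("rel_type") == "m.replace":
            return False
        # a non-empty string body counts
        body = content.get("body")
        return isinstance(body, str) and bool(body)

    assert counts_as_unread({"type": "m.room.message", "content": {"body": "hi"}})
    assert not counts_as_unread(
        {"type": "m.room.message", "content": {"msgtype": "m.notice", "body": "ping"}}
    )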
4444 INCLUDE_ALL_UNREAD_NOTIFS = False
4545
4646
47 class EmailPusher(object):
47 class EmailPusher:
4848 """
4949 A pusher that sends email notifications about events (approximately)
5050 when they happen.
4848 )
4949
5050
51 class HttpPusher(object):
51 class HttpPusher:
5252 INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
5353 MAX_BACKOFF_SEC = 60 * 60
5454
1515 import email.mime.multipart
1616 import email.utils
1717 import logging
18 import time
19 import urllib
18 import urllib.parse
2019 from email.mime.multipart import MIMEMultipart
2120 from email.mime.text import MIMEText
2221 from typing import Iterable, List, TypeVar
9291 # ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"]
9392
9493
95 class Mailer(object):
94 class Mailer:
9695 def __init__(self, hs, app_name, template_html, template_text):
9796 self.hs = hs
9897 self.template_html = template_html
639638 for c in s:
640639 tot += ord(c)
641640 return tot
642
643
644 def format_ts_filter(value, format):
645 return time.strftime(format, time.localtime(value / 1000))
646
647
648 def load_jinja2_templates(
649 template_dir,
650 template_filenames,
651 apply_format_ts_filter=False,
652 apply_mxc_to_http_filter=False,
653 public_baseurl=None,
654 ):
655 """Loads and returns one or more jinja2 templates and applies optional filters
656
657 Args:
658 template_dir (str): The directory where templates are stored
659 template_filenames (list[str]): A list of template filenames
660 apply_format_ts_filter (bool): Whether to apply a template filter that formats
661 timestamps
662 apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
663 mxc urls to http urls
664 public_baseurl (str|None): The public baseurl of the server. Required for
665 apply_mxc_to_http_filter to be enabled
666
667 Returns:
668 A list of jinja2 templates corresponding to the given list of filenames,
669 with order preserved
670 """
671 logger.info(
672 "loading email templates %s from '%s'", template_filenames, template_dir
673 )
674 loader = jinja2.FileSystemLoader(template_dir)
675 env = jinja2.Environment(loader=loader)
676
677 if apply_format_ts_filter:
678 env.filters["format_ts"] = format_ts_filter
679
680 if apply_mxc_to_http_filter and public_baseurl:
681 env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
682
683 templates = []
684 for template_filename in template_filenames:
685 template = env.get_template(template_filename)
686 templates.append(template)
687
688 return templates
689
690
691 def _create_mxc_to_http_filter(public_baseurl):
692 def mxc_to_http_filter(value, width, height, resize_method="crop"):
693 if value[0:6] != "mxc://":
694 return ""
695
696 serverAndMediaId = value[6:]
697 fragment = None
698 if "#" in serverAndMediaId:
699 (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
700 fragment = "#" + fragment
701
702 params = {"width": width, "height": height, "method": resize_method}
703 return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
704 public_baseurl,
705 serverAndMediaId,
706 urllib.parse.urlencode(params),
707 fragment or "",
708 )
709
710 return mxc_to_http_filter
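For reference, the `_create_mxc_to_http_filter` helper removed above rewrote `mxc://` URIs into thumbnail URLs on the local media repository. Its behaviour, assuming a `public_baseurl` with a trailing slash (as the string formatting implies; values illustrative):

    f = _create_mxc_to_http_filter("https://example.com/")
    f("mxc://example.com/abc123", 32, 32)
    # -> "https://example.com/_matrix/media/v1/thumbnail/example.com/abc123"
    #    "?width=32&height=32&method=crop"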
104104 return tweaks
105105
106106
107 class PushRuleEvaluatorForEvent(object):
107 class PushRuleEvaluatorForEvent:
108108 def __init__(
109109 self,
110110 event: EventBase,
1414
1515 import logging
1616
17 from synapse.push.emailpusher import EmailPusher
18 from synapse.push.mailer import Mailer
19
1720 from .httppusher import HttpPusher
1821
1922 logger = logging.getLogger(__name__)
2023
21 # We try importing this if we can (it will fail if we don't
22 # have the optional email dependencies installed). We don't
23 # yet have the config to know if we need the email pusher,
24 # but importing this after daemonizing seems to fail
25 # (even though a simple test of importing from a daemonized
26 # process works fine)
27 try:
28 from synapse.push.emailpusher import EmailPusher
29 from synapse.push.mailer import Mailer, load_jinja2_templates
30 except Exception:
31 pass
3224
33
34 class PusherFactory(object):
25 class PusherFactory:
3526 def __init__(self, hs):
3627 self.hs = hs
3728 self.config = hs.config
4233 if hs.config.email_enable_notifs:
4334 self.mailers = {} # app_name -> Mailer
4435
45 self.notif_template_html, self.notif_template_text = load_jinja2_templates(
46 self.config.email_template_dir,
47 [
48 self.config.email_notif_template_html,
49 self.config.email_notif_template_text,
50 ],
51 apply_format_ts_filter=True,
52 apply_mxc_to_http_filter=True,
53 public_baseurl=self.config.public_baseurl,
54 )
36 self._notif_template_html = hs.config.email_notif_template_html
37 self._notif_template_text = hs.config.email_notif_template_text
5538
5639 self.pusher_types["email"] = self._create_email_pusher
5740
7255 mailer = Mailer(
7356 hs=self.hs,
7457 app_name=app_name,
75 template_html=self.notif_template_html,
76 template_text=self.notif_template_text,
58 template_html=self._notif_template_html,
59 template_text=self._notif_template_text,
7760 )
7861 self.mailers[app_name] = mailer
7962 return EmailPusher(self.hs, pusherdict, mailer)
4242 "jsonschema>=2.5.1",
4343 "frozendict>=1",
4444 "unpaddedbase64>=1.1.0",
45 "canonicaljson>=1.2.0",
45 "canonicaljson>=1.3.0",
4646 # we use the type definitions added in signedjson 1.1.
4747 "signedjson>=1.1.0",
4848 "pynacl>=1.2.1",
6565 "msgpack>=0.5.2",
6666 "phonenumbers>=8.2.0",
6767 "prometheus_client>=0.0.18,<0.9.0",
68 # we use attr.validators.deep_iterable, which arrived in 19.1.0
68 # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
69 # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
70 # is out in November.)
6971 "attrs>=19.1.0",
7072 "netaddr>=0.7.18",
7173 "Jinja2>=2.9",
7779 "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
7880 # we use execute_batch, which arrived in psycopg 2.7.
7981 "postgres": ["psycopg2>=2.7"],
80 # ConsentResource uses select_autoescape, which arrived in jinja 2.9
81 "resources.consent": ["Jinja2>=2.9"],
8282 # ACME support is required to provision TLS certificates from authorities
8383 # that use the protocol, such as Let's Encrypt.
8484 "acme": [
3232 logger = logging.getLogger(__name__)
3333
3434
35 class ReplicationEndpoint(object):
35 class ReplicationEndpoint:
3636 """Helper base class for defining new replication HTTP endpoints.
3737
3838 This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
4343 admin,
4444 user_type,
4545 address,
46 shadow_banned,
4647 ):
4748 """
4849 Args:
5960 user_type (str|None): type of user. One of the values from
6061 api.constants.UserTypes, or None for a normal user.
6162 address (str|None): the IP address used to perform the registration.
63 shadow_banned (bool): Whether to shadow-ban the user
6264 """
6365 return {
6466 "password_hash": password_hash,
6971 "admin": admin,
7072 "user_type": user_type,
7173 "address": address,
74 "shadow_banned": shadow_banned,
7275 }
7376
7477 async def _handle_request(self, request, user_id):
8689 admin=content["admin"],
8790 user_type=content["user_type"],
8891 address=content["address"],
92 shadow_banned=content["shadow_banned"],
8993 )
9094
9195 return 200, {}
1515 from synapse.storage.util.id_generators import _load_current_id
1616
1717
18 class SlavedIdTracker(object):
18 class SlavedIdTracker:
1919 def __init__(self, db_conn, table, column, extra_tables=[], step=1):
2020 self.step = step
2121 self._current = _load_current_id(db_conn, table, column, step)
2222 for table, column in extra_tables:
23 self.advance(_load_current_id(db_conn, table, column))
23 self.advance(None, _load_current_id(db_conn, table, column))
2424
25 def advance(self, new_id):
25 def advance(self, instance_name, new_id):
2626 self._current = (max if self.step > 0 else min)(self._current, new_id)
2727
2828 def get_current_token(self):
3232 int
3333 """
3434 return self._current
35
36 def get_current_token_for_writer(self, instance_name: str) -> int:
37 """Returns the position of the given writer.
38
39 For streams with single writers this is equivalent to
40 `get_current_token`.
41 """
42 return self.get_current_token()
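`SlavedIdTracker.advance` now takes the writer's `instance_name` first, mirroring the multi-writer id-generator interface; a single-writer tracker simply ignores the name, and `get_current_token_for_writer` falls back to `get_current_token`. A sketch of the updated call pattern (instance name illustrative):

    tracker.advance("master", 42)                    # (instance_name, new_id)
    tracker.get_current_token()                      # -> 42
    tracker.get_current_token_for_writer("master")   # same value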
4040
4141 def process_replication_rows(self, stream_name, instance_name, token, rows):
4242 if stream_name == TagAccountDataStream.NAME:
43 self._account_data_id_gen.advance(token)
43 self._account_data_id_gen.advance(instance_name, token)
4444 for row in rows:
4545 self.get_tags_for_user.invalidate((row.user_id,))
4646 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
4747 elif stream_name == AccountDataStream.NAME:
48 self._account_data_id_gen.advance(token)
48 self._account_data_id_gen.advance(instance_name, token)
4949 for row in rows:
5050 if not row.room_id:
5151 self.get_global_account_data_by_type_for_user.invalidate(
4545
4646 def process_replication_rows(self, stream_name, instance_name, token, rows):
4747 if stream_name == ToDeviceStream.NAME:
48 self._device_inbox_id_gen.advance(token)
48 self._device_inbox_id_gen.advance(instance_name, token)
4949 for row in rows:
5050 if row.entity.startswith("@"):
5151 self._device_inbox_stream_cache.entity_has_changed(
4747 "DeviceListFederationStreamChangeCache", device_list_max
4848 )
4949
50 def get_device_stream_token(self) -> int:
51 return self._device_list_id_gen.get_current_token()
52
5053 def process_replication_rows(self, stream_name, instance_name, token, rows):
5154 if stream_name == DeviceListsStream.NAME:
52 self._device_list_id_gen.advance(token)
55 self._device_list_id_gen.advance(instance_name, token)
5356 self._invalidate_caches_for_devices(token, rows)
5457 elif stream_name == UserSignatureStream.NAME:
55 self._device_list_id_gen.advance(token)
58 self._device_list_id_gen.advance(instance_name, token)
5659 for row in rows:
5760 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
5861 return super().process_replication_rows(stream_name, instance_name, token, rows)
3939
4040 def process_replication_rows(self, stream_name, instance_name, token, rows):
4141 if stream_name == GroupServerStream.NAME:
42 self._group_updates_id_gen.advance(token)
42 self._group_updates_id_gen.advance(instance_name, token)
4343 for row in rows:
4444 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
4545
4343
4444 def process_replication_rows(self, stream_name, instance_name, token, rows):
4545 if stream_name == PresenceStream.NAME:
46 self._presence_id_gen.advance(token)
46 self._presence_id_gen.advance(instance_name, token)
4747 for row in rows:
4848 self.presence_stream_cache.entity_has_changed(row.user_id, token)
4949 self._get_presence_for_user.invalidate((row.user_id,))
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
1617 from synapse.replication.tcp.streams import PushRulesStream
1718 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
1819
2021
2122
2223 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
23 def get_push_rules_stream_token(self):
24 return (
25 self._push_rules_stream_id_gen.get_current_token(),
26 self._stream_id_gen.get_current_token(),
27 )
28
2924 def get_max_push_rules_stream_id(self):
3025 return self._push_rules_stream_id_gen.get_current_token()
3126
3227 def process_replication_rows(self, stream_name, instance_name, token, rows):
28 # We assert this for the benefit of mypy
29 assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
30
3331 if stream_name == PushRulesStream.NAME:
34 self._push_rules_stream_id_gen.advance(token)
32 self._push_rules_stream_id_gen.advance(instance_name, token)
3533 for row in rows:
3634 self.get_push_rules_for_user.invalidate((row.user_id,))
3735 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
3333
3434 def process_replication_rows(self, stream_name, instance_name, token, rows):
3535 if stream_name == PushersStream.NAME:
36 self._pushers_id_gen.advance(token)
36 self._pushers_id_gen.advance(instance_name, token)
3737 return super().process_replication_rows(stream_name, instance_name, token, rows)
4545
4646 def process_replication_rows(self, stream_name, instance_name, token, rows):
4747 if stream_name == ReceiptsStream.NAME:
48 self._receipts_id_gen.advance(token)
48 self._receipts_id_gen.advance(instance_name, token)
4949 for row in rows:
5050 self.invalidate_caches_for_receipt(
5151 row.room_id, row.receipt_type, row.user_id
3232
3333 def process_replication_rows(self, stream_name, instance_name, token, rows):
3434 if stream_name == PublicRoomsStream.NAME:
35 self._public_room_id_gen.advance(token)
35 self._public_room_id_gen.advance(instance_name, token)
3636
3737 return super().process_replication_rows(stream_name, instance_name, token, rows)
1313 # limitations under the License.
1414 """A replication client for use by synapse workers.
1515 """
16 import heapq
1716 import logging
1817 from typing import TYPE_CHECKING, Dict, List, Tuple
1918
218217
219218 waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
220219
221 # We insert into the list using heapq as it is more efficient than
222 # pushing then resorting each time.
223 heapq.heappush(waiting_list, (position, deferred))
220 waiting_list.append((position, deferred))
221 waiting_list.sort(key=lambda t: t[0])
224222
225223 # We measure here to get in flight counts and average waiting time.
226224 with Measure(self._clock, "repl.wait_for_stream_position"):
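One practical difference from the removed `heapq` version: `heappush` orders tuples element-wise, so two waiters at the same stream position would fall through to comparing Deferreds, which define no ordering. Sorting with an explicit key only ever compares positions. A minimal demonstration (`Waiter` stands in for a Deferred):

    import heapq

    class Waiter:
        pass  # like a Deferred: no __lt__, so instances are unorderable

    items = [(5, Waiter()), (5, Waiter())]
    items.sort(key=lambda t: t[0])       # fine: only the positions are compared

    heap = []
    heapq.heappush(heap, (5, Waiter()))
    # heapq.heappush(heap, (5, Waiter()))  # would raise TypeError on the tie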
2020 import logging
2121 from typing import Tuple, Type
2222
23 from canonicaljson import json
24
25 from synapse.util import json_encoder as _json_encoder
23 from synapse.util import json_decoder, json_encoder
2624
2725 logger = logging.getLogger(__name__)
2826
124122 stream_name,
125123 instance_name,
126124 None if token == "batch" else int(token),
127 json.loads(row_json),
125 json_decoder.decode(row_json),
128126 )
129127
130128 def to_line(self):
133131 self.stream_name,
134132 self.instance_name,
135133 str(self.token) if self.token is not None else "batch",
136 _json_encoder.encode(self.row),
134 json_encoder.encode(self.row),
137135 )
138136 )
139137
358356 def from_line(cls, line):
359357 user_id, jsn = line.split(" ", 1)
360358
361 access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
359 access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
362360
363361 return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
364362
366364 return (
367365 self.user_id
368366 + " "
369 + _json_encoder.encode(
367 + json_encoder.encode(
370368 (
371369 self.access_token,
372370 self.ip,
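The commands module now shares one JSON encoder/decoder pair from `synapse.util` rather than going through `canonicaljson`. Constructing the pair once and reusing it avoids per-call setup; a sketch of the assumed arrangement (the real instances live in `synapse/util/__init__.py`):

    import json

    json_encoder = json.JSONEncoder(separators=(",", ":"))  # compact output
    json_decoder = json.JSONDecoder()

    line = json_encoder.encode({"token": 5, "row": ["a", "b"]})
    json_decoder.decode(line)  # -> {"token": 5, "row": ["a", "b"]}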
112112 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
113113
114114
115 class ConnectionStates(object):
115 class ConnectionStates:
116116 CONNECTING = "connecting"
117117 ESTABLISHED = "established"
118118 PAUSED = "paused"
5757 )
5858
5959
60 class ReplicationStreamer(object):
60 class ReplicationStreamer:
6161 """Handles replication connections.
6262
6363 This needs to be poked when new replication data may be available. When new
7878 UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
7979
8080
81 class Stream(object):
81 class Stream:
8282 """Base class for the streams.
8383
8484 Provides a `get_updates()` function that returns new updates since the last
351351 )
352352
353353 def _current_token(self, instance_name: str) -> int:
354 push_rules_token, _ = self.store.get_push_rules_stream_token()
354 push_rules_token = self.store.get_max_push_rules_stream_id()
355355 return push_rules_token
356356
357357
404404 store = hs.get_datastore()
405405 super().__init__(
406406 hs.get_instance_name(),
407 store.get_cache_stream_token,
407 store.get_cache_stream_token_for_writer,
408408 store.get_all_updated_caches,
409409 )
410410
4848
4949
5050 @attr.s(slots=True, frozen=True)
51 class EventsStreamRow(object):
51 class EventsStreamRow:
5252 """A parsed row from the events replication stream"""
5353
5454 type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
5555 data = attr.ib() # BaseEventsStreamRow
5656
5757
58 class BaseEventsStreamRow(object):
58 class BaseEventsStreamRow:
5959 """Base class for rows to be sent in the events stream.
6060
6161 Specifies how to identify, serialize and deserialize the different types.
4949 room_keys,
5050 room_upgrade_rest_servlet,
5151 sendtodevice,
52 shared_rooms,
5253 sync,
5354 tags,
5455 thirdparty,
124125 synapse.rest.admin.register_servlets_for_client_rest_resource(
125126 hs, client_resource
126127 )
128
129 # unstable
130 shared_rooms.register_servlets(hs, client_resource)
315315 join_rules_event = room_state.get((EventTypes.JoinRules, ""))
316316 if join_rules_event:
317317 if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
318 # update_membership with an action of "invite" can raise a
319 # ShadowBanError. This is not handled since it is assumed that
320 # an admin isn't going to call this API with a shadow-banned user.
318321 await self.room_member_handler.update_membership(
319322 requester=requester,
320323 target=fake_requester.user,
7272 The parameters `from` and `limit` are required only for pagination.
7373 By default, a `limit` of 100 is used.
7474 The parameter `user_id` can be used to filter by user id.
75 The parameter `name` can be used to filter by user id or display name.
7576 The parameter `guests` can be used to exclude guest users.
7677 The parameter `deactivated` can be used to include deactivated users.
7778 """
8889 start = parse_integer(request, "from", default=0)
8990 limit = parse_integer(request, "limit", default=100)
9091 user_id = parse_string(request, "user_id", default=None)
92 name = parse_string(request, "name", default=None)
9193 guests = parse_boolean(request, "guests", default=True)
9294 deactivated = parse_boolean(request, "deactivated", default=False)
9395
9496 users, total = await self.store.get_users_paginate(
95 start, limit, user_id, guests, deactivated
97 start, limit, user_id, name, guests, deactivated
9698 )
9799 ret = {"users": users, "total": total}
98100 if len(users) >= limit:
2424 CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
2525
2626
27 class HttpTransactionCache(object):
27 class HttpTransactionCache:
2828 def __init__(self, hs):
2929 self.hs = hs
3030 self.auth = self.hs.get_auth()
1717
1818 from synapse.api.errors import Codes, LoginError, SynapseError
1919 from synapse.api.ratelimiting import Ratelimiter
20 from synapse.handlers.auth import (
21 convert_client_dict_legacy_fields_to_identifier,
22 login_id_phone_to_thirdparty,
23 )
2024 from synapse.http.server import finish_request
2125 from synapse.http.servlet import (
2226 RestServlet,
2731 from synapse.rest.client.v2_alpha._base import client_patterns
2832 from synapse.rest.well_known import WellKnownBuilder
2933 from synapse.types import JsonDict, UserID
30 from synapse.util.msisdn import phone_number_to_msisdn
3134 from synapse.util.threepids import canonicalise_email
3235
3336 logger = logging.getLogger(__name__)
34
35
36 def login_submission_legacy_convert(submission):
37 """
38 If the input login submission is an old style object
39 (ie. with top-level user / medium / address) convert it
40 to a typed object.
41 """
42 if "user" in submission:
43 submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
44 del submission["user"]
45
46 if "medium" in submission and "address" in submission:
47 submission["identifier"] = {
48 "type": "m.id.thirdparty",
49 "medium": submission["medium"],
50 "address": submission["address"],
51 }
52 del submission["medium"]
53 del submission["address"]
54
55
56 def login_id_thirdparty_from_phone(identifier):
57 """
58 Convert a phone login identifier type to a generic threepid identifier
59 Args:
60 identifier(dict): Login identifier dict of type 'm.id.phone'
61
62 Returns: Login identifier dict of type 'm.id.threepid'
63 """
64 if "country" not in identifier or (
65 # The specification requires a "phone" field, while Synapse used to require a "number"
66 # field. Accept both for backwards compatibility.
67 "phone" not in identifier
68 and "number" not in identifier
69 ):
70 raise SynapseError(400, "Invalid phone-type identifier")
71
72 # Accept both "phone" and "number" as valid keys in m.id.phone
73 phone_number = identifier.get("phone", identifier["number"])
74
75 msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
76
77 return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
7837
7938
8039 class LoginRestServlet(RestServlet):
193152 login_submission.get("address"),
194153 login_submission.get("user"),
195154 )
196 login_submission_legacy_convert(login_submission)
197
198 if "identifier" not in login_submission:
199 raise SynapseError(400, "Missing param: identifier")
200
201 identifier = login_submission["identifier"]
202 if "type" not in identifier:
203 raise SynapseError(400, "Login identifier has no type")
155 identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
204156
205157 # convert phone type identifiers to generic threepids
206158 if identifier["type"] == "m.id.phone":
207 identifier = login_id_thirdparty_from_phone(identifier)
159 identifier = login_id_phone_to_thirdparty(identifier)
208160
209161 # convert threepid identifiers to user IDs
210162 if identifier["type"] == "m.id.thirdparty":
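The legacy-field handling deleted above now lives in `synapse.handlers.auth`. For reference, the transformations the relocated helpers perform, grounded in the removed code (values illustrative; the msisdn is E.164 without the leading '+'):

    convert_client_dict_legacy_fields_to_identifier(
        {"user": "alice", "type": "m.login.password", "password": "..."}
    )
    # -> {"type": "m.id.user", "user": "alice"}

    login_id_phone_to_thirdparty(
        {"type": "m.id.phone", "country": "GB", "phone": "07700900123"}
    )
    # -> {"type": "m.id.thirdparty", "medium": "msisdn", "address": "447700900123"}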
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1514
1615 from synapse.api.errors import (
1716 NotFoundError,
159158 return 200, {}
160159
161160 def notify_user(self, user_id):
162 stream_id, _ = self.store.get_push_rules_stream_token()
161 stream_id = self.store.get_max_push_rules_stream_id()
163162 self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
164163
165 def set_rule_attr(self, user_id, spec, val):
164 async def set_rule_attr(self, user_id, spec, val):
166165 if spec["attr"] == "enabled":
167166 if isinstance(val, dict) and "enabled" in val:
168167 val = val["enabled"]
172171 # bools directly, so let's not break them.
173172 raise SynapseError(400, "Value for 'enabled' must be boolean")
174173 namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
175 return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
174 return await self.store.set_push_rule_enabled(
175 user_id, namespaced_rule_id, val
176 )
176177 elif spec["attr"] == "actions":
177178 actions = val.get("actions")
178179 _check_actions(actions)
187188
188189 if namespaced_rule_id not in rule_ids:
189190 raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
190 return self.store.set_push_rule_actions(
191 return await self.store.set_push_rule_actions(
191192 user_id, namespaced_rule_id, actions, is_default_rule
192193 )
193194 else:
2020 from typing import List, Optional
2121 from urllib import parse as urlparse
2222
23 from canonicaljson import json
24
2523 from synapse.api.constants import EventTypes, Membership
2624 from synapse.api.errors import (
2725 AuthError,
2826 Codes,
2927 HttpResponseException,
3028 InvalidClientCredentialsError,
29 ShadowBanError,
3130 SynapseError,
3231 )
3332 from synapse.api.filtering import Filter
4544 from synapse.storage.state import StateFilter
4645 from synapse.streams.config import PaginationConfig
4746 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
47 from synapse.util import json_decoder
48 from synapse.util.stringutils import random_string
4849
4950 MYPY = False
5051 if MYPY:
169170 room_id=room_id,
170171 event_type=event_type,
171172 state_key=state_key,
172 is_guest=requester.is_guest,
173173 )
174174
175175 if not data:
199199 if state_key is not None:
200200 event_dict["state_key"] = state_key
201201
202 if event_type == EventTypes.Member:
203 membership = content.get("membership", None)
204 event_id, _ = await self.room_member_handler.update_membership(
205 requester,
206 target=UserID.from_string(state_key),
207 room_id=room_id,
208 action=membership,
209 content=content,
210 )
211 else:
202 try:
203 if event_type == EventTypes.Member:
204 membership = content.get("membership", None)
205 event_id, _ = await self.room_member_handler.update_membership(
206 requester,
207 target=UserID.from_string(state_key),
208 room_id=room_id,
209 action=membership,
210 content=content,
211 )
212 else:
213 (
214 event,
215 _,
216 ) = await self.event_creation_handler.create_and_send_nonmember_event(
217 requester, event_dict, txn_id=txn_id
218 )
219 event_id = event.event_id
220 except ShadowBanError:
221 event_id = "$" + random_string(43)
222
223 set_tag("event_id", event_id)
224 ret = {"event_id": event_id}
225 return 200, ret
226
227
228 # TODO: Needs unit testing for generic events + feedback
229 class RoomSendEventRestServlet(TransactionRestServlet):
230 def __init__(self, hs):
231 super(RoomSendEventRestServlet, self).__init__(hs)
232 self.event_creation_handler = hs.get_event_creation_handler()
233 self.auth = hs.get_auth()
234
235 def register(self, http_server):
236 # /rooms/$roomid/send/$event_type[/$txn_id]
237 PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
238 register_txn_path(self, PATTERNS, http_server, with_get=True)
239
240 async def on_POST(self, request, room_id, event_type, txn_id=None):
241 requester = await self.auth.get_user_by_req(request, allow_guest=True)
242 content = parse_json_object_from_request(request)
243
244 event_dict = {
245 "type": event_type,
246 "content": content,
247 "room_id": room_id,
248 "sender": requester.user.to_string(),
249 }
250
251 if b"ts" in request.args and requester.app_service:
252 event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
253
254 try:
212255 (
213256 event,
214257 _,
216259 requester, event_dict, txn_id=txn_id
217260 )
218261 event_id = event.event_id
262 except ShadowBanError:
263 event_id = "$" + random_string(43)
219264
220265 set_tag("event_id", event_id)
221 ret = {"event_id": event_id}
222 return 200, ret
223
224
225 # TODO: Needs unit testing for generic events + feedback
226 class RoomSendEventRestServlet(TransactionRestServlet):
227 def __init__(self, hs):
228 super(RoomSendEventRestServlet, self).__init__(hs)
229 self.event_creation_handler = hs.get_event_creation_handler()
230 self.auth = hs.get_auth()
231
232 def register(self, http_server):
233 # /rooms/$roomid/send/$event_type[/$txn_id]
234 PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
235 register_txn_path(self, PATTERNS, http_server, with_get=True)
236
237 async def on_POST(self, request, room_id, event_type, txn_id=None):
238 requester = await self.auth.get_user_by_req(request, allow_guest=True)
239 content = parse_json_object_from_request(request)
240
241 event_dict = {
242 "type": event_type,
243 "content": content,
244 "room_id": room_id,
245 "sender": requester.user.to_string(),
246 }
247
248 if b"ts" in request.args and requester.app_service:
249 event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
250
251 event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
252 requester, event_dict, txn_id=txn_id
253 )
254
255 set_tag("event_id", event.event_id)
256 return 200, {"event_id": event.event_id}
266 return 200, {"event_id": event_id}
257267
258268 def on_GET(self, request, room_id, event_type, txn_id):
259269 return 200, "Not implemented"
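The try/except blocks added throughout this file implement one pattern: when the sender is shadow-banned, the request must still look like it succeeded, so a realistic but fabricated event ID (a 43-character localpart, the same length as a real room-v4 event ID) is returned instead of an error. Condensed sketch (`send_the_event` is a placeholder for the create-and-send call inside the async handler):

    from synapse.api.errors import ShadowBanError
    from synapse.util.stringutils import random_string

    async def handle(request):
        try:
            event_id = await send_the_event()  # may raise ShadowBanError
        except ShadowBanError:
            # the client cannot tell a dropped event from a delivered one
            event_id = "$" + random_string(43)
        return 200, {"event_id": event_id}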
518528 filter_str = parse_string(request, b"filter", encoding="utf-8")
519529 if filter_str:
520530 filter_json = urlparse.unquote(filter_str)
521 event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
531 event_filter = Filter(
532 json_decoder.decode(filter_json)
533 ) # type: Optional[Filter]
522534 if (
523535 event_filter
524536 and event_filter.filter_json.get("event_format", "client")
630642 filter_str = parse_string(request, b"filter", encoding="utf-8")
631643 if filter_str:
632644 filter_json = urlparse.unquote(filter_str)
633 event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
645 event_filter = Filter(
646 json_decoder.decode(filter_json)
647 ) # type: Optional[Filter]
634648 else:
635649 event_filter = None
636650
715729 content = {}
716730
717731 if membership_action == "invite" and self._has_3pid_invite_keys(content):
718 await self.room_member_handler.do_3pid_invite(
719 room_id,
720 requester.user,
721 content["medium"],
722 content["address"],
723 content["id_server"],
724 requester,
725 txn_id,
726 content.get("id_access_token"),
727 )
732 try:
733 await self.room_member_handler.do_3pid_invite(
734 room_id,
735 requester.user,
736 content["medium"],
737 content["address"],
738 content["id_server"],
739 requester,
740 txn_id,
741 content.get("id_access_token"),
742 )
743 except ShadowBanError:
744 # Pretend the request succeeded.
745 pass
728746 return 200, {}
729747
730748 target = requester.user
736754 if "reason" in content:
737755 event_content = {"reason": content["reason"]}
738756
739 await self.room_member_handler.update_membership(
740 requester=requester,
741 target=target,
742 room_id=room_id,
743 action=membership_action,
744 txn_id=txn_id,
745 third_party_signed=content.get("third_party_signed", None),
746 content=event_content,
747 )
757 try:
758 await self.room_member_handler.update_membership(
759 requester=requester,
760 target=target,
761 room_id=room_id,
762 action=membership_action,
763 txn_id=txn_id,
764 third_party_signed=content.get("third_party_signed", None),
765 content=event_content,
766 )
767 except ShadowBanError:
768 # Pretend the request succeeded.
769 pass
748770
749771 return_value = {}
750772
782804 requester = await self.auth.get_user_by_req(request)
783805 content = parse_json_object_from_request(request)
784806
785 event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
786 requester,
787 {
788 "type": EventTypes.Redaction,
789 "content": content,
790 "room_id": room_id,
791 "sender": requester.user.to_string(),
792 "redacts": event_id,
793 },
794 txn_id=txn_id,
795 )
796
797 set_tag("event_id", event.event_id)
798 return 200, {"event_id": event.event_id}
807 try:
808 (
809 event,
810 _,
811 ) = await self.event_creation_handler.create_and_send_nonmember_event(
812 requester,
813 {
814 "type": EventTypes.Redaction,
815 "content": content,
816 "room_id": room_id,
817 "sender": requester.user.to_string(),
818 "redacts": event_id,
819 },
820 txn_id=txn_id,
821 )
822 event_id = event.event_id
823 except ShadowBanError:
824 event_id = "$" + random_string(43)
825
826 set_tag("event_id", event_id)
827 return 200, {"event_id": event_id}
799828
800829 def on_PUT(self, request, room_id, event_id, txn_id):
801830 set_tag("txn_id", txn_id)
838867 # Limit timeout to stop people from setting silly typing timeouts.
839868 timeout = min(content.get("timeout", 30000), 120000)
840869
841 if content["typing"]:
842 await self.typing_handler.started_typing(
843 target_user=target_user,
844 auth_user=requester.user,
845 room_id=room_id,
846 timeout=timeout,
847 )
848 else:
849 await self.typing_handler.stopped_typing(
850 target_user=target_user, auth_user=requester.user, room_id=room_id
851 )
870 try:
871 if content["typing"]:
872 await self.typing_handler.started_typing(
873 target_user=target_user,
874 requester=requester,
875 room_id=room_id,
876 timeout=timeout,
877 )
878 else:
879 await self.typing_handler.stopped_typing(
880 target_user=target_user, requester=requester, room_id=room_id
881 )
882 except ShadowBanError:
883 # Pretend this worked without error.
884 pass
852885
853886 return 200, {}
854887
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616 import logging
17 import random
1718 from http import HTTPStatus
1819
1920 from synapse.api.constants import LoginType
3132 parse_json_object_from_request,
3233 parse_string,
3334 )
34 from synapse.push.mailer import Mailer, load_jinja2_templates
35 from synapse.push.mailer import Mailer
3536 from synapse.util.msisdn import phone_number_to_msisdn
3637 from synapse.util.stringutils import assert_valid_client_secret, random_string
3738 from synapse.util.threepids import canonicalise_email, check_3pid_allowed
5253 self.identity_handler = hs.get_handlers().identity_handler
5354
5455 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
55 template_html, template_text = load_jinja2_templates(
56 self.config.email_template_dir,
57 [
58 self.config.email_password_reset_template_html,
59 self.config.email_password_reset_template_text,
60 ],
61 apply_format_ts_filter=True,
62 apply_mxc_to_http_filter=True,
63 public_baseurl=self.config.public_baseurl,
64 )
6556 self.mailer = Mailer(
6657 hs=self.hs,
6758 app_name=self.config.email_app_name,
68 template_html=template_html,
69 template_text=template_text,
59 template_html=self.config.email_password_reset_template_html,
60 template_text=self.config.email_password_reset_template_text,
7061 )
7162
7263 async def on_POST(self, request):
118109 if self.config.request_token_inhibit_3pid_errors:
119110 # Make the client think the operation succeeded. See the rationale in the
120111 # comments for request_token_inhibit_3pid_errors.
112 # Also wait for some random amount of time between 100ms and 1s to make it
113 # look like we did something.
114 await self.hs.clock.sleep(random.randint(1, 10) / 10)
121115 return 200, {"sid": random_string(16)}
122116
123117 raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
168162 self.clock = hs.get_clock()
169163 self.store = hs.get_datastore()
170164 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
171 (self.failure_email_template,) = load_jinja2_templates(
172 self.config.email_template_dir,
173 [self.config.email_password_reset_template_failure_html],
165 self._failure_email_template = (
166 self.config.email_password_reset_template_failure_html
174167 )
175168
176169 async def on_GET(self, request, medium):
213206 return None
214207
215208 # Otherwise show the success template
216 html = self.config.email_password_reset_template_success_html
209 html = self.config.email_password_reset_template_success_html_content
217210 status_code = 200
218211 except ThreepidValidationError as e:
219212 status_code = e.code
220213
221214 # Show a failure page with a reason
222215 template_vars = {"failure_reason": e.msg}
223 html = self.failure_email_template.render(**template_vars)
216 html = self._failure_email_template.render(**template_vars)
224217
225218 respond_with_html(request, status_code, html)
226219
410403 self.store = self.hs.get_datastore()
411404
412405 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
413 template_html, template_text = load_jinja2_templates(
414 self.config.email_template_dir,
415 [
416 self.config.email_add_threepid_template_html,
417 self.config.email_add_threepid_template_text,
418 ],
419 public_baseurl=self.config.public_baseurl,
420 )
421406 self.mailer = Mailer(
422407 hs=self.hs,
423408 app_name=self.config.email_app_name,
424 template_html=template_html,
425 template_text=template_text,
409 template_html=self.config.email_add_threepid_template_html,
410 template_text=self.config.email_add_threepid_template_text,
426411 )
427412
428413 async def on_POST(self, request):
466451 if self.config.request_token_inhibit_3pid_errors:
467452 # Make the client think the operation succeeded. See the rationale in the
468453 # comments for request_token_inhibit_3pid_errors.
454 # Also wait for some random amount of time between 100ms and 1s to make it
455 # look like we did something.
456 await self.hs.clock.sleep(random.randint(1, 10) / 10)
469457 return 200, {"sid": random_string(16)}
470458
471459 raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
534522 if self.hs.config.request_token_inhibit_3pid_errors:
535523 # Make the client think the operation succeeded. See the rationale in the
536524 # comments for request_token_inhibit_3pid_errors.
525 # Also wait for some random amount of time between 100ms and 1s to make it
526 # look like we did something.
527 await self.hs.clock.sleep(random.randint(1, 10) / 10)
537528 return 200, {"sid": random_string(16)}
538529
539530 raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
577568 self.clock = hs.get_clock()
578569 self.store = hs.get_datastore()
579570 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
580 (self.failure_email_template,) = load_jinja2_templates(
581 self.config.email_template_dir,
582 [self.config.email_add_threepid_template_failure_html],
571 self._failure_email_template = (
572 self.config.email_add_threepid_template_failure_html
583573 )
584574
585575 async def on_GET(self, request):
630620
631621 # Show a failure page with a reason
632622 template_vars = {"failure_reason": e.msg}
633 html = self.failure_email_template.render(**template_vars)
623 html = self._failure_email_template.render(**template_vars)
634624
635625 respond_with_html(request, status_code, html)
636626
1515
1616 import logging
1717
18 from synapse.api.errors import SynapseError
1819 from synapse.http.servlet import RestServlet, parse_json_object_from_request
1920 from synapse.types import GroupID
2021
323324 async def on_GET(self, request, group_id):
324325 requester = await self.auth.get_user_by_req(request, allow_guest=True)
325326 requester_user_id = requester.user.to_string()
327
328 if not GroupID.is_valid(group_id):
329 raise SynapseError(400, "%s is not a valid group ID" % (group_id,))
326330
327331 result = await self.groups_handler.get_rooms_in_group(
328332 group_id, requester_user_id
1515
1616 import hmac
1717 import logging
18 import random
1819 from typing import List, Union
1920
2021 import synapse
4344 parse_json_object_from_request,
4445 parse_string,
4546 )
46 from synapse.push.mailer import load_jinja2_templates
47 from synapse.push.mailer import Mailer
4748 from synapse.util.msisdn import phone_number_to_msisdn
4849 from synapse.util.ratelimitutils import FederationRateLimiter
4950 from synapse.util.stringutils import assert_valid_client_secret, random_string
8081 self.config = hs.config
8182
8283 if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
83 from synapse.push.mailer import Mailer, load_jinja2_templates
84
85 template_html, template_text = load_jinja2_templates(
86 self.config.email_template_dir,
87 [
88 self.config.email_registration_template_html,
89 self.config.email_registration_template_text,
90 ],
91 apply_format_ts_filter=True,
92 apply_mxc_to_http_filter=True,
93 public_baseurl=self.config.public_baseurl,
94 )
9584 self.mailer = Mailer(
9685 hs=self.hs,
9786 app_name=self.config.email_app_name,
98 template_html=template_html,
99 template_text=template_text,
87 template_html=self.config.email_registration_template_html,
88 template_text=self.config.email_registration_template_text,
10089 )
10190
10291 async def on_POST(self, request):
142131 if self.hs.config.request_token_inhibit_3pid_errors:
143132 # Make the client think the operation succeeded. See the rationale in the
144133 # comments for request_token_inhibit_3pid_errors.
134 # Also wait for some random amount of time between 100ms and 1s to make it
135 # look like we did something.
136 await self.hs.clock.sleep(random.randint(1, 10) / 10)
145137 return 200, {"sid": random_string(16)}
146138
147139 raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
214206 if self.hs.config.request_token_inhibit_3pid_errors:
215207 # Make the client think the operation succeeded. See the rationale in the
216208 # comments for request_token_inhibit_3pid_errors.
209 # Also wait for some random amount of time between 100ms and 1s to make it
210 # look like we did something.
211 await self.hs.clock.sleep(random.randint(1, 10) / 10)
217212 return 200, {"sid": random_string(16)}
218213
219214 raise SynapseError(
261256 self.store = hs.get_datastore()
262257
263258 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
264 (self.failure_email_template,) = load_jinja2_templates(
265 self.config.email_template_dir,
266 [self.config.email_registration_template_failure_html],
267 )
268
269 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
270 (self.failure_email_template,) = load_jinja2_templates(
271 self.config.email_template_dir,
272 [self.config.email_registration_template_failure_html],
259 self._failure_email_template = (
260 self.config.email_registration_template_failure_html
273261 )
274262
275263 async def on_GET(self, request, medium):
317305
318306 # Show a failure page with a reason
319307 template_vars = {"failure_reason": e.msg}
320 html = self.failure_email_template.render(**template_vars)
308 html = self._failure_email_template.render(**template_vars)
321309
322310 respond_with_html(request, status_code, html)
323311
609597 Codes.THREEPID_IN_USE,
610598 )
611599
600 entries = await self.store.get_user_agents_ips_to_ui_auth_session(
601 session_id
602 )
603
612604 registered_user_id = await self.registration_handler.register_user(
613605 localpart=desired_username,
614606 password_hash=password_hash,
615607 guest_access_token=guest_access_token,
616608 threepid=threepid,
617609 address=client_addr,
610 user_agent_ips=entries,
618611 )
619612 # Necessary due to auth checks prior to the threepid being
620613 # written to the db
664657 (object) params: registration parameters, from which we pull
665658 device_id, initial_device_name and inhibit_login
666659 Returns:
667 (object) dictionary for response from /register
660 dictionary for response from /register
668661 """
669662 result = {"user_id": user_id, "home_server": self.hs.hostname}
670663 if not params.get("inhibit_login", False):
2121 import logging
2222
2323 from synapse.api.constants import EventTypes, RelationTypes
24 from synapse.api.errors import SynapseError
24 from synapse.api.errors import ShadowBanError, SynapseError
2525 from synapse.http.servlet import (
2626 RestServlet,
2727 parse_integer,
3434 PaginationChunk,
3535 RelationPaginationToken,
3636 )
37 from synapse.util.stringutils import random_string
3738
3839 from ._base import client_patterns
3940
110111 "sender": requester.user.to_string(),
111112 }
112113
113 event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
114 requester, event_dict=event_dict, txn_id=txn_id
115 )
116
117 return 200, {"event_id": event.event_id}
114 try:
115 (
116 event,
117 _,
118 ) = await self.event_creation_handler.create_and_send_nonmember_event(
119 requester, event_dict=event_dict, txn_id=txn_id
120 )
121 event_id = event.event_id
122 except ShadowBanError:
123 event_id = "$" + random_string(43)
124
125 return 200, {"event_id": event_id}
118126
119127
120128 class RelationPaginationServlet(RestServlet):
1414
1515 import logging
1616
17 from synapse.api.errors import Codes, SynapseError
17 from synapse.api.errors import Codes, ShadowBanError, SynapseError
1818 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
1919 from synapse.http.servlet import (
2020 RestServlet,
2121 assert_params_in_dict,
2222 parse_json_object_from_request,
2323 )
24 from synapse.util import stringutils
2425
2526 from ._base import client_patterns
2627
6162
6263 content = parse_json_object_from_request(request)
6364 assert_params_in_dict(content, ("new_version",))
64 new_version = content["new_version"]
6565
6666 new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
6767 if new_version is None:
7171 Codes.UNSUPPORTED_ROOM_VERSION,
7272 )
7373
74 new_room_id = await self._room_creation_handler.upgrade_room(
75 requester, room_id, new_version
76 )
74 try:
75 new_room_id = await self._room_creation_handler.upgrade_room(
76 requester, room_id, new_version
77 )
78 except ShadowBanError:
79 # Generate a random room ID.
80 new_room_id = stringutils.random_string(18)
7781
7882 ret = {"replacement_room": new_room_id}
7983
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 Half-Shot
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import logging
15
16 from synapse.api.errors import Codes, SynapseError
17 from synapse.http.servlet import RestServlet
18 from synapse.types import UserID
19
20 from ._base import client_patterns
21
22 logger = logging.getLogger(__name__)
23
24
25 class UserSharedRoomsServlet(RestServlet):
26 """
27 GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
28 """
29
30 PATTERNS = client_patterns(
31 "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
32 releases=(), # This is an unstable feature
33 )
34
35 def __init__(self, hs):
36 super(UserSharedRoomsServlet, self).__init__()
37 self.auth = hs.get_auth()
38 self.store = hs.get_datastore()
39 self.user_directory_active = hs.config.update_user_directory
40
41 async def on_GET(self, request, user_id):
42
43 if not self.user_directory_active:
44 raise SynapseError(
45 code=400,
46 msg="The user directory is disabled on this server. Cannot determine shared rooms.",
47 errcode=Codes.FORBIDDEN,
48 )
49
50 UserID.from_string(user_id)
51
52 requester = await self.auth.get_user_by_req(request)
53 if user_id == requester.user.to_string():
54 raise SynapseError(
55 code=400,
56 msg="You cannot request a list of shared rooms with yourself",
57 errcode=Codes.FORBIDDEN,
58 )
59 rooms = await self.store.get_shared_rooms_for_users(
60 requester.user.to_string(), user_id
61 )
62
63 return 200, {"joined": list(rooms)}
64
65
66 def register_servlets(hs, http_server):
67 UserSharedRoomsServlet(hs).register(http_server)
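An illustrative exchange against the new endpoint (host, token, and room made up; with `releases=()` the servlet is mounted only under the unstable prefix):

    GET /_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/@bob:example.com HTTP/1.1
    Authorization: Bearer <access_token>

    HTTP/1.1 200 OK
    {"joined": ["!abc123:example.com"]}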
1515 import itertools
1616 import logging
1717
18 from canonicaljson import json
19
2018 from synapse.api.constants import PresenceState
2119 from synapse.api.errors import Codes, StoreError, SynapseError
2220 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
2826 from synapse.handlers.sync import SyncConfig
2927 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
3028 from synapse.types import StreamToken
29 from synapse.util import json_decoder
3130
3231 from ._base import client_patterns, set_timeline_upper_limit
3332
124123 filter_collection = DEFAULT_FILTER_COLLECTION
125124 elif filter_id.startswith("{"):
126125 try:
127 filter_object = json.loads(filter_id)
126 filter_object = json_decoder.decode(filter_id)
128127 set_timeline_upper_limit(
129128 filter_object, self.hs.config.filter_timeline_limit
130129 )
425424 result["ephemeral"] = {"events": ephemeral_events}
426425 result["unread_notifications"] = room.unread_notifications
427426 result["summary"] = room.summary
427 result["org.matrix.msc2654.unread_count"] = room.unread_count
428428
429429 return result
430430
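Each joined room in the sync response now carries the MSC2654 unread count alongside the existing notification counts. An illustrative per-room fragment (counts made up):

    result = {
        "unread_notifications": {"notification_count": 4, "highlight_count": 1},
        "org.matrix.msc2654.unread_count": 7,
    }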
1818 import logging
1919 import re
2020
21 from synapse.api.constants import RoomCreationPreset
2122 from synapse.http.servlet import RestServlet
2223
2324 logger = logging.getLogger(__name__)
2930 def __init__(self, hs):
3031 super(VersionsRestServlet, self).__init__()
3132 self.config = hs.config
33
34 # Calculate these once since they shouldn't change after start-up.
35 self.e2ee_forced_public = (
36 RoomCreationPreset.PUBLIC_CHAT
37 in self.config.encryption_enabled_by_default_for_room_presets
38 )
39 self.e2ee_forced_private = (
40 RoomCreationPreset.PRIVATE_CHAT
41 in self.config.encryption_enabled_by_default_for_room_presets
42 )
43 self.e2ee_forced_trusted_private = (
44 RoomCreationPreset.TRUSTED_PRIVATE_CHAT
45 in self.config.encryption_enabled_by_default_for_room_presets
46 )
3247
3348 def on_GET(self, request):
3449 return (
5974 "org.matrix.e2e_cross_signing": True,
6075 # Implements additional endpoints as described in MSC2432
6176 "org.matrix.msc2432": True,
77 # Implements additional endpoints as described in MSC2666
78 "uk.half-shot.msc2666": True,
79 # Whether new rooms will be set to encrypted or not (based on presets).
80 "io.element.e2ee_forced.public": self.e2ee_forced_public,
81 "io.element.e2ee_forced.private": self.e2ee_forced_private,
82 "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
6283 },
6384 },
6485 )
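With, say, only the private presets listed in `encryption_enabled_by_default_for_room_presets`, the advertised flags would come out as follows (illustrative fragment of the `/versions` response body):

    {
        "unstable_features": {
            "uk.half-shot.msc2666": true,
            "io.element.e2ee_forced.public": false,
            "io.element.e2ee_forced.private": true,
            "io.element.e2ee_forced.trusted_private": true
        }
    }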
1414 import logging
1515 from typing import Dict, Set
1616
17 from canonicaljson import encode_canonical_json, json
1817 from signedjson.sign import sign_json
1918
2019 from synapse.api.errors import Codes, SynapseError
2120 from synapse.crypto.keyring import ServerKeyFetcher
22 from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
21 from synapse.http.server import DirectServeJsonResource, respond_with_json
2322 from synapse.http.servlet import parse_integer, parse_json_object_from_request
23 from synapse.util import json_decoder
2424
2525 logger = logging.getLogger(__name__)
2626
2727
2828 class RemoteKey(DirectServeJsonResource):
29 """HTTP resource for retreiving the TLS certificate and NACL signature
29 """HTTP resource for retrieving the TLS certificate and NACL signature
3030 verification keys for a collection of servers. Checks that the reported
3131 X.509 TLS certificate matches the one used in the HTTPS connection. Checks
3232 that the NACL signature for the remote server is valid. Returns a dict of
208208 # Cast to bytes since postgresql returns a memoryview.
209209 json_results.add(bytes(result["key_json"]))
210210
211 # If there is a cache miss, request the missing keys, then recurse (and
212 # ensure the result is sent).
211213 if cache_misses and query_remote_on_cache_miss:
212214 await self.fetcher.get_keys(cache_misses)
213215 await self.query_keys(request, query, query_remote_on_cache_miss=False)
214216 else:
215217 signed_keys = []
216218 for key_json in json_results:
217 key_json = json.loads(key_json.decode("utf-8"))
219 key_json = json_decoder.decode(key_json.decode("utf-8"))
218220 for signing_key in self.config.key_server_signing_keys:
219221 key_json = sign_json(key_json, self.config.server_name, signing_key)
220222
222224
223225 results = {"server_keys": signed_keys}
224226
225 respond_with_json_bytes(request, 200, encode_canonical_json(results))
227 respond_with_json(request, 200, results, canonical_json=True)
234234 finish_request(request)
235235
236236
237 class Responder(object):
237 class Responder:
238238 """Represents a response that can be streamed to the requester.
239239
240240 Responder is a context manager which *must* be used, so that any resources
259259 pass
260260
261261
262 class FileInfo(object):
262 class FileInfo:
263263 """Details about a requested/uploaded file.
264264
265265 Attributes:
3232 return _wrapped
3333
3434
35 class MediaFilePaths(object):
35 class MediaFilePaths:
3636 """Describes where files are stored on disk.
3737
3838 Most of the functions have a `*_rel` variant which returns a file path that
6161 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
6262
6363
64 class MediaRepository(object):
64 class MediaRepository:
6565 def __init__(self, hs):
6666 self.hs = hs
6767 self.auth = hs.get_auth()
3333 logger = logging.getLogger(__name__)
3434
3535
36 class MediaStorage(object):
36 class MediaStorage:
3737 """Responsible for storing/fetching files from local sources.
3838
3939 Args:
3030 }
3131
3232
33 class Thumbnailer(object):
33 class Thumbnailer:
3434
3535 FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
3636
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import json
1615 import logging
1716
1817 from twisted.web.resource import Resource
1918
2019 from synapse.http.server import set_cors_headers
20 from synapse.util import json_encoder
2121
2222 logger = logging.getLogger(__name__)
2323
2424
25 class WellKnownBuilder(object):
25 class WellKnownBuilder:
2626 """Utility to construct the well-known response
2727
2828 Args:
6666
6767 logger.debug("returning: %s", r)
6868 request.setHeader(b"Content-Type", b"application/json")
69 return json.dumps(r).encode("utf-8")
69 return json_encoder.encode(r).encode("utf-8")
3636 import binascii
3737 import os
3838
39 class Secrets(object):
39 class Secrets:
4040 def token_bytes(self, nbytes=32):
4141 return os.urandom(nbytes)
4242
113113 from synapse.types import DomainSpecificString
114114 from synapse.util import Clock
115115 from synapse.util.distributor import Distributor
116 from synapse.util.ratelimitutils import FederationRateLimiter
116117 from synapse.util.stringutils import random_string
117118
118119 logger = logging.getLogger(__name__)
641642 def get_replication_streams(self) -> Dict[str, Stream]:
642643 return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
643644
645 @cache_in_self
646 def get_federation_ratelimiter(self) -> FederationRateLimiter:
647 return FederationRateLimiter(self.clock, config=self.config.rc_federation)
648
644649 async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
645650 return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
646651
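`get_federation_ratelimiter` is registered with `@cache_in_self`, so a single `FederationRateLimiter` is built lazily and shared across the homeserver. A minimal sketch of the memoisation idea, not Synapse's actual implementation:

```python
# Sketch only: build a dependency on first access, then reuse it per instance.
import functools

def cache_in_self(builder):
    attr = "_cached_" + builder.__name__

    @functools.wraps(builder)
    def _get(self):
        value = getattr(self, attr, None)
        if value is None:
            value = builder(self)
            setattr(self, attr, value)
        return value

    return _get
```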
2222 logger = logging.getLogger(__name__)
2323
2424
25 class ConsentServerNotices(object):
25 class ConsentServerNotices:
2626 """Keeps track of whether we need to send users server_notices about
2727 privacy policy consent, and sends one if we do.
2828 """
2626 logger = logging.getLogger(__name__)
2727
2828
29 class ResourceLimitsServerNotices(object):
29 class ResourceLimitsServerNotices:
3030 """ Keeps track of whether the server has reached it's resource limit and
3131 ensures that the client is kept up to date.
3232 """
2424 SERVER_NOTICE_ROOM_TAG = "m.server_notice"
2525
2626
27 class ServerNoticesManager(object):
27 class ServerNoticesManager:
2828 def __init__(self, hs):
2929 """
3030
1919 )
2020
2121
22 class ServerNoticesSender(object):
22 class ServerNoticesSender:
2323 """A centralised place which sends server notices automatically when
2424 certain events take place
2525 """
1313 # limitations under the License.
1414
1515
16 class WorkerServerNoticesSender(object):
16 class WorkerServerNoticesSender:
1717 """Stub impl of ServerNoticesSender which does nothing"""
1818
1919 def __init__(self, hs):
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from enum import Enum
1516
1617 from twisted.internet import defer
1718
2425 logger = logging.getLogger(__name__)
2526
2627
27 class SpamCheckerApi(object):
28 class RegistrationBehaviour(Enum):
29 """
30 Enum to define whether a registration request should be allowed, denied, or shadow-banned.
31 """
32
33 ALLOW = "allow"
34 SHADOW_BAN = "shadow_ban"
35 DENY = "deny"
36
37
38 class SpamCheckerApi:
2839 """A proxy object that gets passed to spam checkers so they can get
2940 access to rooms and other relevant information.
3041 """
4758 twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
4859 The filtered state events in the room.
4960 """
50 state_ids = yield self._store.get_filtered_current_state_ids(
51 room_id=room_id, state_filter=StateFilter.from_types(types)
61 state_ids = yield defer.ensureDeferred(
62 self._store.get_filtered_current_state_ids(
63 room_id=room_id, state_filter=StateFilter.from_types(types)
64 )
5265 )
53 state = yield self._store.get_events(state_ids.values())
66 state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
5467 return state.values()
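The change above is a common bridging pattern in this release: an old-style `@inlineCallbacks` generator can only `yield` Deferreds, so coroutines returned by the now-async storage methods are wrapped with `defer.ensureDeferred` first. A self-contained sketch:

```python
# Sketch of bridging a coroutine into an @inlineCallbacks generator.
from twisted.internet import defer

async def fetch_count():
    return 42

@defer.inlineCallbacks
def legacy_caller():
    # yield only accepts Deferreds, so wrap the coroutine first
    count = yield defer.ensureDeferred(fetch_count())
    return count
```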
1515
1616 import logging
1717 from collections import namedtuple
18 from typing import Awaitable, Dict, Iterable, List, Optional, Set
18 from typing import (
19 Awaitable,
20 Dict,
21 Iterable,
22 List,
23 Optional,
24 Sequence,
25 Set,
26 Union,
27 cast,
28 overload,
29 )
1930
2031 import attr
2132 from frozendict import frozendict
2233 from prometheus_client import Histogram
34 from typing_extensions import Literal
2335
2436 from synapse.api.constants import EventTypes
2537 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
2941 from synapse.state import v1, v2
3042 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
3143 from synapse.storage.roommember import ProfileInfo
32 from synapse.types import StateMap
44 from synapse.types import Collection, MutableStateMap, StateMap
3345 from synapse.util import Clock
3446 from synapse.util.async_helpers import Linearizer
3547 from synapse.util.caches.expiringcache import ExpiringCache
6476 return s
6577
6678
67 class _StateCacheEntry(object):
79 class _StateCacheEntry:
6880 __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
6981
70 def __init__(self, state, state_group, prev_group=None, delta_ids=None):
71 # dict[(str, str), str] map from (type, state_key) to event_id
82 def __init__(
83 self,
84 state: StateMap[str],
85 state_group: Optional[int],
86 prev_group: Optional[int] = None,
87 delta_ids: Optional[StateMap[str]] = None,
88 ):
89 # A map from (type, state_key) to event_id.
7290 self.state = frozendict(state)
7391
7492 # the ID of a state group if one and only one is involved.
94112 return len(self.state)
95113
96114
97 class StateHandler(object):
115 class StateHandler:
98116 """Fetches bits of state from the stores, and does state resolution
99117 where necessary
100118 """
106124 self.hs = hs
107125 self._state_resolution_handler = hs.get_state_resolution_handler()
108126
127 @overload
109128 async def get_current_state(
110 self, room_id, event_type=None, state_key="", latest_event_ids=None
111 ):
112 """ Retrieves the current state for the room. This is done by
129 self,
130 room_id: str,
131 event_type: Literal[None] = None,
132 state_key: str = "",
133 latest_event_ids: Optional[List[str]] = None,
134 ) -> StateMap[EventBase]:
135 ...
136
137 @overload
138 async def get_current_state(
139 self,
140 room_id: str,
141 event_type: str,
142 state_key: str = "",
143 latest_event_ids: Optional[List[str]] = None,
144 ) -> Optional[EventBase]:
145 ...
146
147 async def get_current_state(
148 self,
149 room_id: str,
150 event_type: Optional[str] = None,
151 state_key: str = "",
152 latest_event_ids: Optional[List[str]] = None,
153 ) -> Union[Optional[EventBase], StateMap[EventBase]]:
154 """Retrieves the current state for the room. This is done by
113155 calling `get_latest_events_in_room` to get the leading edges of the
114156 event graph and then resolving any of the state conflicts.
115157
116158 This is equivalent to getting the state of an event we were about to
117159 send next, before receiving any new events.
118160
119 If `event_type` is specified, then the method returns only the one
120 event (or None) with that `event_type` and `state_key`.
121
122 Returns:
123 map from (type, state_key) to event
161 Returns:
162 If `event_type` is specified, then the method returns only the one
163 event (or None) with that `event_type` and `state_key`.
164
165 Otherwise, a map from (type, state_key) to event.
124166 """
125167 if not latest_event_ids:
126168 latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
169 assert latest_event_ids is not None
127170
128171 logger.debug("calling resolve_state_groups from get_current_state")
129172 ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
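The `@overload` declarations above let mypy infer a different return type depending on whether `event_type` is passed, using `Literal[None]` to match the no-argument case. A self-contained sketch of the same pattern:

```python
# Sketch of the @overload + Literal[None] pattern used by get_current_state.
from typing import Dict, Optional, Union, overload

from typing_extensions import Literal

@overload
def get_state(event_type: Literal[None] = None) -> Dict[str, str]:
    ...

@overload
def get_state(event_type: str) -> Optional[str]:
    ...

def get_state(
    event_type: Optional[str] = None,
) -> Union[Dict[str, str], Optional[str]]:
    state = {"m.room.name": "$name_event_id"}
    if event_type is None:
        return state  # the full state map
    return state.get(event_type)  # one (possibly missing) entry
```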
139182 state_map = await self.store.get_events(
140183 list(state.values()), get_prev_content=False
141184 )
142 state = {
185 return {
143186 key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
144187 }
145188
146 return state
147
148 async def get_current_state_ids(self, room_id, latest_event_ids=None):
189 async def get_current_state_ids(
190 self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
191 ) -> StateMap[str]:
149192 """Get the current state, or the state at a set of events, for a room
150193
151194 Args:
152 room_id (str):
153
154 latest_event_ids (iterable[str]|None): if given, the forward
155 extremities to resolve. If None, we look them up from the
156 database (via a cache)
157
158 Returns:
159 Deferred[dict[(str, str), str)]]: the state dict, mapping from
160 (event_type, state_key) -> event_id
195 room_id:
196 latest_event_ids: if given, the forward extremities to resolve. If
197 None, we look them up from the database (via a cache).
198
199 Returns:
200 the state dict, mapping from (event_type, state_key) -> event_id
161201 """
162202 if not latest_event_ids:
163203 latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
204 assert latest_event_ids is not None
164205
165206 logger.debug("calling resolve_state_groups from get_current_state_ids")
166207 ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
167 state = ret.state
168
169 return state
208 return ret.state
170209
171210 async def get_current_users_in_room(
172211 self, room_id: str, latest_event_ids: Optional[List[str]] = None
182221 """
183222 if not latest_event_ids:
184223 latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
224 assert latest_event_ids is not None
225
185226 logger.debug("calling resolve_state_groups from get_current_users_in_room")
186227 entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
187 joined_users = await self.store.get_joined_users_from_state(room_id, entry)
188 return joined_users
189
190 async def get_current_hosts_in_room(self, room_id):
228 return await self.store.get_joined_users_from_state(room_id, entry)
229
230 async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
191231 event_ids = await self.store.get_latest_event_ids_in_room(room_id)
192232 return await self.get_hosts_in_room_at_events(room_id, event_ids)
193233
194 async def get_hosts_in_room_at_events(self, room_id, event_ids):
234 async def get_hosts_in_room_at_events(
235 self, room_id: str, event_ids: List[str]
236 ) -> Set[str]:
195237 """Get the hosts that were in a room at the given event ids
196238
197239 Args:
198 room_id (str):
199 event_ids (list[str]):
200
201 Returns:
202 Deferred[list[str]]: the hosts in the room at the given events
240 room_id:
241 event_ids:
242
243 Returns:
244 The hosts in the room at the given events
203245 """
204246 entry = await self.resolve_state_groups_for_events(room_id, event_ids)
205 joined_hosts = await self.store.get_joined_hosts(room_id, entry)
206 return joined_hosts
247 return await self.store.get_joined_hosts(room_id, entry)
207248
208249 async def compute_event_context(
209250 self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
210 ):
251 ) -> EventContext:
211252 """Build an EventContext structure for the event.
212253
213254 This works out what the current state should be for the event, and
220261 when receiving an event from federation where we don't have the
221262 prev events for, e.g. when backfilling.
222263 Returns:
223 synapse.events.snapshot.EventContext:
264 The event context.
224265 """
225266
226267 if event.internal_metadata.is_outlier():
261302 # if we're given the state before the event, then we use that
262303 state_ids_before_event = {
263304 (s.type, s.state_key): s.event_id for s in old_state
264 }
305 } # type: StateMap[str]
265306 state_group_before_event = None
266307 state_group_before_event_prev_group = None
267308 deltas_to_state_group_before_event = None
345386 )
346387
347388 @measure_func()
348 async def resolve_state_groups_for_events(self, room_id, event_ids):
389 async def resolve_state_groups_for_events(
390 self, room_id: str, event_ids: Iterable[str]
391 ) -> _StateCacheEntry:
349392 """ Given a list of event_ids this method fetches the state at each
350393 event, resolves conflicts between them and returns them.
351394
352395 Args:
353 room_id (str)
354 event_ids (list[str])
355 explicit_room_version (str|None): If set uses the the given room
356 version to choose the resolution algorithm. If None, then
357 checks the database for room version.
358
359 Returns:
360 Deferred[_StateCacheEntry]: resolved state
396 room_id
397 event_ids
398
399 Returns:
400 The resolved state
361401 """
362402 logger.debug("resolve_state_groups event_ids %s", event_ids)
363403
393433 )
394434 return result
395435
396 async def resolve_events(self, room_version, state_sets, event):
436 async def resolve_events(
437 self,
438 room_version: str,
439 state_sets: Collection[Iterable[EventBase]],
440 event: EventBase,
441 ) -> StateMap[EventBase]:
397442 logger.info(
398443 "Resolving state for %s with %d groups", event.room_id, len(state_sets)
399444 )
413458 state_res_store=StateResolutionStore(self.store),
414459 )
415460
416 new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
417
418 return new_state
419
420
421 class StateResolutionHandler(object):
461 return {key: state_map[ev_id] for key, ev_id in new_state.items()}
462
463
464 class StateResolutionHandler:
422465 """Responsible for doing state conflict resolution.
423466
424467 Note that the storage layer depends on this handler, so all functions must
443486
444487 @log_function
445488 async def resolve_state_groups(
446 self, room_id, room_version, state_groups_ids, event_map, state_res_store
489 self,
490 room_id: str,
491 room_version: str,
492 state_groups_ids: Dict[int, StateMap[str]],
493 event_map: Optional[Dict[str, EventBase]],
494 state_res_store: "StateResolutionStore",
447495 ):
448496 """Resolves conflicts between a set of state groups
449497
451499 not be called for a single state group
452500
453501 Args:
454 room_id (str): room we are resolving for (used for logging and sanity checks)
455 room_version (str): version of the room
456 state_groups_ids (dict[int, dict[(str, str), str]]):
457 map from state group id to the state in that state group
502 room_id: room we are resolving for (used for logging and sanity checks)
503 room_version: version of the room
504 state_groups_ids:
505 A map from state group id to the state in that state group
458506 (where 'state' is a map from state key to event id)
459507
460 event_map(dict[str,FrozenEvent]|None):
508 event_map:
461509 a dict from event_id to event, for any events that we happen to
462510 have in flight (eg, those currently being persisted). This will be
463511 used as a starting point for finding the state we need; any missing
465513
466514 If None, all events will be fetched via state_res_store.
467515
468 state_res_store (StateResolutionStore)
469
470 Returns:
471 _StateCacheEntry: resolved state
516 state_res_store
517
518 Returns:
519 The resolved state
472520 """
473521 logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
474522
492540 #
493541 # XXX: is this actually worthwhile, or should we just let
494542 # resolve_events_with_store do it?
495 new_state = {}
543 new_state = {} # type: MutableStateMap[str]
496544 conflicted_state = False
497545 for st in state_groups_ids.values():
498546 for key, e_id in st.items():
506554 if conflicted_state:
507555 logger.info("Resolving conflicted state for %r", room_id)
508556 with Measure(self.clock, "state._resolve_events"):
509 new_state = await resolve_events_with_store(
510 self.clock,
511 room_id,
512 room_version,
513 list(state_groups_ids.values()),
514 event_map=event_map,
515 state_res_store=state_res_store,
557 # resolve_events_with_store returns a StateMap, but we can
558 # treat it as a MutableStateMap as it is above. It isn't
559 # actually mutated anymore (and is frozen in
560 # _make_state_cache_entry below).
561 new_state = cast(
562 MutableStateMap,
563 await resolve_events_with_store(
564 self.clock,
565 room_id,
566 room_version,
567 list(state_groups_ids.values()),
568 event_map=event_map,
569 state_res_store=state_res_store,
570 ),
516571 )
517572
518573 # if the new state matches any of the input state groups, we can
529584 return cache
530585
531586
532 def _make_state_cache_entry(new_state, state_groups_ids):
587 def _make_state_cache_entry(
588 new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
589 ) -> _StateCacheEntry:
533590 """Given a resolved state, and a set of input state groups, pick one to base
534591 a new state group on (if any), and return an appropriately-constructed
535592 _StateCacheEntry.
536593
537594 Args:
538 new_state (dict[(str, str), str]): resolved state map (mapping from
539 (type, state_key) to event_id)
540
541 state_groups_ids (dict[int, dict[(str, str), str]]):
542 map from state group id to the state in that state group
543 (where 'state' is a map from state key to event id)
595 new_state: resolved state map (mapping from (type, state_key) to event_id)
596
597 state_groups_ids:
598 map from state group id to the state in that state group (where
599 'state' is a map from state key to event id)
544600
545601 Returns:
546 _StateCacheEntry
602 The cache entry.
547603 """
548604 # if the new state matches any of the input state groups, we can
549605 # use that state group again. Otherwise we will generate a state_id
584640 clock: Clock,
585641 room_id: str,
586642 room_version: str,
587 state_sets: List[StateMap[str]],
643 state_sets: Sequence[StateMap[str]],
588644 event_map: Optional[Dict[str, EventBase]],
589645 state_res_store: "StateResolutionStore",
590646 ) -> Awaitable[StateMap[str]]:
622678
623679
624680 @attr.s
625 class StateResolutionStore(object):
681 class StateResolutionStore:
626682 """Interface that allows state resolution algorithms to access the database
627683 in a well-defined way.
628684
632688
633689 store = attr.ib()
634690
635 def get_events(self, event_ids, allow_rejected=False):
691 def get_events(
692 self, event_ids: Iterable[str], allow_rejected: bool = False
693 ) -> Awaitable[Dict[str, EventBase]]:
636694 """Get events from the database
637695
638696 Args:
639 event_ids (list): The event_ids of the events to fetch
640 allow_rejected (bool): If True return rejected events.
641
642 Returns:
643 Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
697 event_ids: The event_ids of the events to fetch
698 allow_rejected: If True return rejected events.
699
700 Returns:
701 An awaitable which resolves to a dict from event_id to event.
644702 """
645703
646704 return self.store.get_events(
650708 allow_rejected=allow_rejected,
651709 )
652710
653 def get_auth_chain_difference(self, state_sets: List[Set[str]]):
711 def get_auth_chain_difference(
712 self, state_sets: List[Set[str]]
713 ) -> Awaitable[Set[str]]:
654714 """Given sets of state events figure out the auth chain difference (as
655715 per state res v2 algorithm).
656716
659719 chain.
660720
661721 Returns:
662 Deferred[Set[str]]: Set of event IDs.
722 An awaitable that resolves to a set of event IDs.
663723 """
664724
665725 return self.store.get_auth_chain_difference(state_sets)
1414
1515 import hashlib
1616 import logging
17 from typing import Awaitable, Callable, Dict, List, Optional
17 from typing import (
18 Awaitable,
19 Callable,
20 Dict,
21 Iterable,
22 List,
23 Optional,
24 Sequence,
25 Set,
26 Tuple,
27 )
1828
1929 from synapse import event_auth
2030 from synapse.api.constants import EventTypes
2131 from synapse.api.errors import AuthError
2232 from synapse.api.room_versions import RoomVersions
2333 from synapse.events import EventBase
24 from synapse.types import StateMap
34 from synapse.types import MutableStateMap, StateMap
2535
2636 logger = logging.getLogger(__name__)
2737
3141
3242 async def resolve_events_with_store(
3343 room_id: str,
34 state_sets: List[StateMap[str]],
44 state_sets: Sequence[StateMap[str]],
3545 event_map: Optional[Dict[str, EventBase]],
36 state_map_factory: Callable[[List[str]], Awaitable],
37 ):
46 state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
47 ) -> StateMap[str]:
3848 """
3949 Args:
4050 room_id: the room we are working in
5565 an Awaitable that resolves to a dict of event_id to event.
5666
5767 Returns:
58 Deferred[dict[(str, str), str]]:
59 a map from (type, state_key) to event_id.
68 A map from (type, state_key) to event_id.
6069 """
6170 if len(state_sets) == 1:
6271 return state_sets[0]
7483 "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
7584 )
7685
77 # dict[str, FrozenEvent]: a map from state event id to event. Only includes
78 # the state events which are in conflict (and those in event_map)
86 # A map from state event id to event. Only includes the state events which
87 # are in conflict (and those in event_map).
7988 state_map = await state_map_factory(needed_events)
8089 if event_map is not None:
8190 state_map.update(event_map)
9099
91100 # get the ids of the auth events which allow us to authenticate the
92101 # conflicted state, picking only from the unconflicting state.
93 #
94 # dict[(str, str), str]: a map from state key to event id
95102 auth_events = _create_auth_events_from_maps(
96103 unconflicted_state, conflicted_state, state_map
97104 )
121128 )
122129
123130
124 def _seperate(state_sets):
131 def _seperate(
132 state_sets: Iterable[StateMap[str]],
133 ) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
125134 """Takes the state_sets and figures out which keys are conflicted and
126135 which aren't. i.e., which have multiple different event_ids associated
127136 with them in different state sets.
128137
129138 Args:
130 state_sets(iterable[dict[(str, str), str]]):
139 state_sets:
131140 List of dicts of (type, state_key) -> event_id, which are the
132141 different state groups to resolve.
133142
134143 Returns:
135 (dict[(str, str), str], dict[(str, str), set[str]]):
136 A tuple of (unconflicted_state, conflicted_state), where:
137
138 unconflicted_state is a dict mapping (type, state_key)->event_id
139 for unconflicted state keys.
140
141 conflicted_state is a dict mapping (type, state_key) to a set of
142 event ids for conflicted state keys.
144 A tuple of (unconflicted_state, conflicted_state), where:
145
146 unconflicted_state is a dict mapping (type, state_key)->event_id
147 for unconflicted state keys.
148
149 conflicted_state is a dict mapping (type, state_key) to a set of
150 event ids for conflicted state keys.
143151 """
144152 state_set_iterator = iter(state_sets)
145153 unconflicted_state = dict(next(state_set_iterator))
146 conflicted_state = {}
154 conflicted_state = {} # type: MutableStateMap[Set[str]]
147155
148156 for state_set in state_set_iterator:
149157 for key, value in state_set.items():
170178 return unconflicted_state, conflicted_state
171179
172180
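A worked example of the separation performed by `_seperate` (event IDs are made up): two state sets that agree on the room name but disagree on the topic.

```python
# Illustrative input/output for _seperate; "$..." event IDs are invented.
state_sets = [
    {("m.room.name", ""): "$name", ("m.room.topic", ""): "$topic_a"},
    {("m.room.name", ""): "$name", ("m.room.topic", ""): "$topic_b"},
]
# _seperate(state_sets) would return:
#   unconflicted_state == {("m.room.name", ""): "$name"}
#   conflicted_state   == {("m.room.topic", ""): {"$topic_a", "$topic_b"}}
```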
173 def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
181 def _create_auth_events_from_maps(
182 unconflicted_state: StateMap[str],
183 conflicted_state: StateMap[Set[str]],
184 state_map: Dict[str, EventBase],
185 ) -> StateMap[str]:
186 """
187
188 Args:
189 unconflicted_state: The unconflicted state map.
190 conflicted_state: The conflicted state map.
191 state_map:
192
193 Returns:
194 A map from state key to event id.
195 """
174196 auth_events = {}
175197 for event_ids in conflicted_state.values():
176198 for event_id in event_ids:
178200 keys = event_auth.auth_types_for_event(state_map[event_id])
179201 for key in keys:
180202 if key not in auth_events:
181 event_id = unconflicted_state.get(key, None)
182 if event_id:
183 auth_events[key] = event_id
203 auth_event_id = unconflicted_state.get(key, None)
204 if auth_event_id:
205 auth_events[key] = auth_event_id
184206 return auth_events
185207
186208
187209 def _resolve_with_state(
188 unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
210 unconflicted_state_ids: MutableStateMap[str],
211 conflicted_state_ids: StateMap[Set[str]],
212 auth_event_ids: StateMap[str],
213 state_map: Dict[str, EventBase],
189214 ):
190215 conflicted_state = {}
191216 for key, event_ids in conflicted_state_ids.items():
214239 return new_state
215240
216241
217 def _resolve_state_events(conflicted_state, auth_events):
242 def _resolve_state_events(
243 conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
244 ) -> StateMap[EventBase]:
218245 """ This is where we actually decide which of the conflicted state to
219246 use.
220247
254281 return resolved_state
255282
256283
257 def _resolve_auth_events(events, auth_events):
284 def _resolve_auth_events(
285 events: List[EventBase], auth_events: StateMap[EventBase]
286 ) -> EventBase:
258287 reverse = list(reversed(_ordered_events(events)))
259288
260289 auth_keys = {
288317 return event
289318
290319
291 def _resolve_normal_events(events, auth_events):
320 def _resolve_normal_events(
321 events: List[EventBase], auth_events: StateMap[EventBase]
322 ) -> EventBase:
292323 for event in _ordered_events(events):
293324 try:
294325 # The signatures have already been checked at this point
308339 return event
309340
310341
311 def _ordered_events(events):
342 def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
312343 def key_func(e):
313344 # we have to use utf-8 rather than ascii here because it turns out we allow
314345 # people to send us events with non-ascii event IDs :/
1515 import heapq
1616 import itertools
1717 import logging
18 from typing import Dict, List, Optional
18 from typing import (
19 Any,
20 Callable,
21 Dict,
22 Generator,
23 Iterable,
24 List,
25 Optional,
26 Sequence,
27 Set,
28 Tuple,
29 overload,
30 )
31
32 from typing_extensions import Literal
1933
2034 import synapse.state
2135 from synapse import event_auth
2337 from synapse.api.errors import AuthError
2438 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
2539 from synapse.events import EventBase
26 from synapse.types import StateMap
40 from synapse.types import MutableStateMap, StateMap
2741 from synapse.util import Clock
2842
2943 logger = logging.getLogger(__name__)
3953 clock: Clock,
4054 room_id: str,
4155 room_version: str,
42 state_sets: List[StateMap[str]],
56 state_sets: Sequence[StateMap[str]],
4357 event_map: Optional[Dict[str, EventBase]],
4458 state_res_store: "synapse.state.StateResolutionStore",
45 ):
59 ) -> StateMap[str]:
4660 """Resolves the state using the v2 state resolution algorithm
4761
4862 Args:
6276 state_res_store:
6377
6478 Returns:
65 Deferred[dict[(str, str), str]]:
66 a map from (type, state_key) to event_id.
79 A map from (type, state_key) to event_id.
6780 """
6881
6982 logger.debug("Computing conflicted state")
170183 return resolved_state
171184
172185
173 async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
186 async def _get_power_level_for_sender(
187 room_id: str,
188 event_id: str,
189 event_map: Dict[str, EventBase],
190 state_res_store: "synapse.state.StateResolutionStore",
191 ) -> int:
174192 """Return the power level of the sender of the given event according to
175193 their auth events.
176194
177195 Args:
178 room_id (str)
179 event_id (str)
180 event_map (dict[str,FrozenEvent])
181 state_res_store (StateResolutionStore)
182
183 Returns:
184 Deferred[int]
196 room_id
197 event_id
198 event_map
199 state_res_store
200
201 Returns:
202 The power level.
185203 """
186204 event = await _get_event(room_id, event_id, event_map, state_res_store)
187205
216234 return int(level)
217235
218236
219 async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
237 async def _get_auth_chain_difference(
238 state_sets: Sequence[StateMap[str]],
239 event_map: Dict[str, EventBase],
240 state_res_store: "synapse.state.StateResolutionStore",
241 ) -> Set[str]:
220242 """Compare the auth chains of each state set and return the set of events
221243 that only appear in some but not all of the auth chains.
222244
223245 Args:
224 state_sets (list)
225 event_map (dict[str,FrozenEvent])
226 state_res_store (StateResolutionStore)
227
228 Returns:
229 Deferred[set[str]]: Set of event IDs
246 state_sets
247 event_map
248 state_res_store
249
250 Returns:
251 Set of event IDs
230252 """
231253
232254 difference = await state_res_store.get_auth_chain_difference(
236258 return difference
237259
238260
239 def _seperate(state_sets):
261 def _seperate(
262 state_sets: Iterable[StateMap[str]],
263 ) -> Tuple[StateMap[str], StateMap[Set[str]]]:
240264 """Return the unconflicted and conflicted state. This is different than in
241265 the original algorithm, as this defines a key to be conflicted if one of
242266 the state sets doesn't have that key.
243267
244268 Args:
245 state_sets (list)
246
247 Returns:
248 tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
249 conflicted state dict is a map from type/state_key to set of event IDs
269 state_sets
270
271 Returns:
272 A tuple of unconflicted and conflicted state. The conflicted state dict
273 is a map from type/state_key to set of event IDs
250274 """
251275 unconflicted_state = {}
252276 conflicted_state = {}
259283 event_ids.discard(None)
260284 conflicted_state[key] = event_ids
261285
262 return unconflicted_state, conflicted_state
263
264
265 def _is_power_event(event):
286 # mypy doesn't understand that discarding None above means that conflicted
287 # state is StateMap[Set[str]], not StateMap[Set[Optional[str]]].
288 return unconflicted_state, conflicted_state # type: ignore
289
290
291 def _is_power_event(event: EventBase) -> bool:
266292 """Return whether or not the event is a "power event", as defined by the
267293 v2 state resolution algorithm
268294
269295 Args:
270 event (FrozenEvent)
271
272 Returns:
273 boolean
296 event
297
298 Returns:
299 True if the event is a power event.
274300 """
275301 if (event.type, event.state_key) in (
276302 (EventTypes.PowerLevels, ""),
287313
288314
289315 async def _add_event_and_auth_chain_to_graph(
290 graph, room_id, event_id, event_map, state_res_store, auth_diff
291 ):
316 graph: Dict[str, Set[str]],
317 room_id: str,
318 event_id: str,
319 event_map: Dict[str, EventBase],
320 state_res_store: "synapse.state.StateResolutionStore",
321 auth_diff: Set[str],
322 ) -> None:
292323 """Helper function for _reverse_topological_power_sort that add the event
293324 and its auth chain (that is in the auth diff) to the graph
294325
295326 Args:
296 graph (dict[str, set[str]]): A map from event ID to the events auth
297 event IDs
298 room_id (str): the room we are working in
299 event_id (str): Event to add to the graph
300 event_map (dict[str,FrozenEvent])
301 state_res_store (StateResolutionStore)
302 auth_diff (set[str]): Set of event IDs that are in the auth difference.
327 graph: A map from event ID to the event's auth event IDs
328 room_id: the room we are working in
329 event_id: Event to add to the graph
330 event_map
331 state_res_store
332 auth_diff: Set of event IDs that are in the auth difference.
303333 """
304334
305335 state = [event_id]
317347
318348
319349 async def _reverse_topological_power_sort(
320 clock, room_id, event_ids, event_map, state_res_store, auth_diff
321 ):
350 clock: Clock,
351 room_id: str,
352 event_ids: Iterable[str],
353 event_map: Dict[str, EventBase],
354 state_res_store: "synapse.state.StateResolutionStore",
355 auth_diff: Set[str],
356 ) -> List[str]:
322357 """Returns a list of the event_ids sorted by reverse topological ordering,
323358 and then by power level and origin_server_ts
324359
325360 Args:
326 clock (Clock)
327 room_id (str): the room we are working in
328 event_ids (list[str]): The events to sort
329 event_map (dict[str,FrozenEvent])
330 state_res_store (StateResolutionStore)
331 auth_diff (set[str]): Set of event IDs that are in the auth difference.
332
333 Returns:
334 Deferred[list[str]]: The sorted list
335 """
336
337 graph = {}
361 clock
362 room_id: the room we are working in
363 event_ids: The events to sort
364 event_map
365 state_res_store
366 auth_diff: Set of event IDs that are in the auth difference.
367
368 Returns:
369 The sorted list
370 """
371
372 graph = {} # type: Dict[str, Set[str]]
338373 for idx, event_id in enumerate(event_ids, start=1):
339374 await _add_event_and_auth_chain_to_graph(
340375 graph, room_id, event_id, event_map, state_res_store, auth_diff
371406
372407
373408 async def _iterative_auth_checks(
374 clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
375 ):
409 clock: Clock,
410 room_id: str,
411 room_version: str,
412 event_ids: List[str],
413 base_state: StateMap[str],
414 event_map: Dict[str, EventBase],
415 state_res_store: "synapse.state.StateResolutionStore",
416 ) -> MutableStateMap[str]:
376417 """Sequentially apply auth checks to each event in given list, updating the
377418 state as it goes along.
378419
379420 Args:
380 clock (Clock)
381 room_id (str)
382 room_version (str)
383 event_ids (list[str]): Ordered list of events to apply auth checks to
384 base_state (StateMap[str]): The set of state to start with
385 event_map (dict[str,FrozenEvent])
386 state_res_store (StateResolutionStore)
387
388 Returns:
389 Deferred[StateMap[str]]: Returns the final updated state
390 """
391 resolved_state = base_state.copy()
421 clock
422 room_id
423 room_version
424 event_ids: Ordered list of events to apply auth checks to
425 base_state: The set of state to start with
426 event_map
427 state_res_store
428
429 Returns:
430 Returns the final updated state
431 """
432 resolved_state = dict(base_state)
392433 room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
393434
394435 for idx, event_id in enumerate(event_ids, start=1):
438479
439480
440481 async def _mainline_sort(
441 clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
442 ):
482 clock: Clock,
483 room_id: str,
484 event_ids: List[str],
485 resolved_power_event_id: Optional[str],
486 event_map: Dict[str, EventBase],
487 state_res_store: "synapse.state.StateResolutionStore",
488 ) -> List[str]:
443489 """Returns a sorted list of event_ids sorted by mainline ordering based on
444490 the given event resolved_power_event_id
445491
446492 Args:
447 clock (Clock)
448 room_id (str): room we're working in
449 event_ids (list[str]): Events to sort
450 resolved_power_event_id (str): The final resolved power level event ID
451 event_map (dict[str,FrozenEvent])
452 state_res_store (StateResolutionStore)
453
454 Returns:
455 Deferred[list[str]]: The sorted list
493 clock
494 room_id: room we're working in
495 event_ids: Events to sort
496 resolved_power_event_id: The final resolved power level event ID
497 event_map
498 state_res_store
499
500 Returns:
501 The sorted list
456502 """
457503 if not event_ids:
458504 # It's possible for there to be no event IDs here to sort, so we can
504550
505551
506552 async def _get_mainline_depth_for_event(
507 event, mainline_map, event_map, state_res_store
508 ):
553 event: EventBase,
554 mainline_map: Dict[str, int],
555 event_map: Dict[str, EventBase],
556 state_res_store: "synapse.state.StateResolutionStore",
557 ) -> int:
509558 """Get the mainline depths for the given event based on the mainline map
510559
511560 Args:
512 event (FrozenEvent)
513 mainline_map (dict[str, int]): Map from event_id to mainline depth for
514 events in the mainline.
515 event_map (dict[str,FrozenEvent])
516 state_res_store (StateResolutionStore)
517
518 Returns:
519 Deferred[int]
561 event
562 mainline_map: Map from event_id to mainline depth for events in the mainline.
563 event_map
564 state_res_store
565
566 Returns:
567 The mainline depth
520568 """
521569
522570 room_id = event.room_id
571 tmp_event = event # type: Optional[EventBase]
523572
524573 # We do an iterative search, replacing `event` with the power level in its
525574 # auth events (if any)
526 while event:
575 while tmp_event:
527576 depth = mainline_map.get(tmp_event.event_id)
528577 if depth is not None:
529578 return depth
530579
531 auth_events = event.auth_event_ids()
532 event = None
580 auth_events = tmp_event.auth_event_ids()
581 tmp_event = None
533582
534583 for aid in auth_events:
535584 aev = await _get_event(
536585 room_id, aid, event_map, state_res_store, allow_none=True
537586 )
538587 if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
539 event = aev
588 tmp_event = aev
540589 break
541590
542591 # Didn't find a power level auth event, so we just return 0
543592 return 0
544593
545594
546 async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
595 @overload
596 async def _get_event(
597 room_id: str,
598 event_id: str,
599 event_map: Dict[str, EventBase],
600 state_res_store: "synapse.state.StateResolutionStore",
601 allow_none: Literal[False] = False,
602 ) -> EventBase:
603 ...
604
605
606 @overload
607 async def _get_event(
608 room_id: str,
609 event_id: str,
610 event_map: Dict[str, EventBase],
611 state_res_store: "synapse.state.StateResolutionStore",
612 allow_none: Literal[True],
613 ) -> Optional[EventBase]:
614 ...
615
616
617 async def _get_event(
618 room_id: str,
619 event_id: str,
620 event_map: Dict[str, EventBase],
621 state_res_store: "synapse.state.StateResolutionStore",
622 allow_none: bool = False,
623 ) -> Optional[EventBase]:
547624 """Helper function to look up event in event_map, falling back to looking
548625 it up in the store
549626
550627 Args:
551 room_id (str)
552 event_id (str)
553 event_map (dict[str,FrozenEvent])
554 state_res_store (StateResolutionStore)
555 allow_none (bool): if the event is not found, return None rather than raising
628 room_id
629 event_id
630 event_map
631 state_res_store
632 allow_none: if the event is not found, return None rather than raising
556633 an exception
557634
558635 Returns:
559 Deferred[Optional[FrozenEvent]]
636 The event, or none if the event does not exist (and allow_none is True).
560637 """
561638 if event_id not in event_map:
562639 events = await state_res_store.get_events([event_id], allow_rejected=True)
576653 return event
577654
578655
579 def lexicographical_topological_sort(graph, key):
656 def lexicographical_topological_sort(
657 graph: Dict[str, Set[str]], key: Callable[[str], Any]
658 ) -> Generator[str, None, None]:
580659 """Performs a lexicographic reverse topological sort on the graph.
581660
582661 This returns a reverse topological sort (i.e. if node A references B then B
586665 NOTE: `graph` is modified during the sort.
587666
588667 Args:
589 graph (dict[str, set[str]]): A representation of the graph where each
590 node is a key in the dict and its value are the nodes edges.
591 key (func): A function that takes a node and returns a value that is
592 comparable and used to order nodes
668 graph: A representation of the graph where each node is a key in the
669 dict and its value is the node's edges.
670 key: A function that takes a node and returns a value that is comparable
671 and used to order nodes
593672
594673 Yields:
595 str: The next node in the topological sort
674 The next node in the topological sort
596675 """
597676
598677 # Note, this is basically Kahn's algorithm except we look at nodes with no
599678 # outgoing edges, c.f.
600679 # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
601680 outdegree_map = graph
602 reverse_graph = {}
681 reverse_graph = {} # type: Dict[str, Set[str]]
603682
604683 # Lists of nodes with zero out degree. Is actually a tuple of
605684 # `(key(node), node)` so that sorting does the right thing
3636 __all__ = ["DataStores", "DataStore"]
3737
3838
39 class Storage(object):
39 class Storage:
4040 """The high level interfaces for talking to various storage layers.
4141 """
4242
1818 from abc import ABCMeta
1919 from typing import Any, Optional
2020
21 from canonicaljson import json
22
2321 from synapse.storage.database import LoggingTransaction # noqa: F401
2422 from synapse.storage.database import make_in_list_sql_clause # noqa: F401
2523 from synapse.storage.database import DatabasePool
2624 from synapse.types import Collection, get_domain_from_id
25 from synapse.util import json_decoder
2726
2827 logger = logging.getLogger(__name__)
2928
9897 if isinstance(db_content, memoryview):
9998 db_content = db_content.tobytes()
10099
101 # Decode it to a Unicode string before feeding it to json.loads, since
100 # Decode it to a Unicode string before feeding it to the JSON decoder, since
102101 # Python 3.5 does not support deserializing bytes.
103102 if isinstance(db_content, (bytes, bytearray)):
104103 db_content = db_content.decode("utf8")
105104
106105 try:
107 return json.loads(db_content)
106 return json_decoder.decode(db_content)
108107 except Exception:
109108 logging.warning("Tried to decode '%r' as JSON and failed", db_content)
110109 raise
1515 import logging
1616 from typing import Optional
1717
18 from canonicaljson import json
19
20 from twisted.internet import defer
21
2218 from synapse.metrics.background_process_metrics import run_as_background_process
19 from synapse.util import json_encoder
2320
2421 from . import engines
2522
2623 logger = logging.getLogger(__name__)
2724
2825
29 class BackgroundUpdatePerformance(object):
26 class BackgroundUpdatePerformance:
3027 """Tracks the how long a background update is taking to update its items"""
3128
3229 def __init__(self, name):
7370 return float(self.total_item_count) / float(self.total_duration_ms)
7471
7572
76 class BackgroundUpdater(object):
73 class BackgroundUpdater:
7774 """ Background updates are updates to the database that run in the
7875 background. Each update processes a batch of data at once. We attempt to
7976 limit the impact of each update by monitoring how long each batch takes to
307304 update_name (str): Name of update
308305 """
309306
310 @defer.inlineCallbacks
311 def noop_update(progress, batch_size):
312 yield self._end_background_update(update_name)
307 async def noop_update(progress, batch_size):
308 await self._end_background_update(update_name)
313309 return 1
314310
315311 self.register_background_update_handler(update_name, noop_update)
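The hunk above converts the no-op handler from `@defer.inlineCallbacks` to a plain coroutine. A sketch of registering a handler in the new style; `updates` is assumed to be a `BackgroundUpdater` instance, and the update name is made up:

```python
# Sketch: the new async handler style for background updates.
async def do_one_batch(progress, batch_size):
    # process up to `batch_size` items here, then report how many were done
    return batch_size

updates.register_background_update_handler("example_update", do_one_batch)
```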
408404 else:
409405 runner = create_index_sqlite
410406
411 @defer.inlineCallbacks
412 def updater(progress, batch_size):
407 async def updater(progress, batch_size):
413408 if runner is not None:
414409 logger.info("Adding index %s to %s", index_name, table)
415 yield self.db_pool.runWithConnection(runner)
416 yield self._end_background_update(update_name)
410 await self.db_pool.runWithConnection(runner)
411 await self._end_background_update(update_name)
417412 return 1
418413
419414 self.register_background_update_handler(update_name, updater)
420415
421 def _end_background_update(self, update_name):
416 async def _end_background_update(self, update_name: str) -> None:
422417 """Removes a completed background update task from the queue.
423418
424419 Args:
425 update_name(str): The name of the completed task to remove
420 update_name: The name of the completed task to remove
421
426422 Returns:
427 A deferred that completes once the task is removed.
423 None, completes once the task is removed.
428424 """
429425 if update_name != self._current_background_update:
430426 raise Exception(
432428 % update_name
433429 )
434430 self._current_background_update = None
435 return self.db_pool.simple_delete_one(
431 await self.db_pool.simple_delete_one(
436432 "background_updates", keyvalues={"update_name": update_name}
437433 )
438434
439 def _background_update_progress(self, update_name: str, progress: dict):
435 async def _background_update_progress(self, update_name: str, progress: dict):
440436 """Update the progress of a background update
441437
442438 Args:
444440 progress: The progress of the update.
445441 """
446442
447 return self.db_pool.runInteraction(
443 return await self.db_pool.runInteraction(
448444 "background_update_progress",
449445 self._background_update_progress_txn,
450446 update_name,
460456 progress(dict): The progress of the update.
461457 """
462458
463 progress_json = json.dumps(progress)
459 progress_json = json_encoder.encode(progress)
464460
465461 self.db_pool.simple_update_one_txn(
466462 txn,
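This hunk swaps `json.dumps(progress)` for the shared `json_encoder` from `synapse.util`. A minimal sketch of the pattern (synapse.util's exact encoder configuration may differ): one module-level `JSONEncoder` instance reused instead of repeated `json.dumps()` calls.

```python
# Sketch of the shared-encoder pattern; the progress payload is illustrative.
import json

json_encoder = json.JSONEncoder()

progress_json = json_encoder.encode({"last_processed_id": 1234})
```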
2727 Optional,
2828 Tuple,
2929 TypeVar,
30 cast,
31 overload,
3032 )
3133
3234 from prometheus_client import Histogram
35 from typing_extensions import Literal
3336
3437 from twisted.enterprise import adbapi
35 from twisted.internet import defer
3638
3739 from synapse.api.errors import StoreError
3840 from synapse.config.database import DatabaseConnectionConfig
124126 method.
125127
126128 Args:
127 txn: The database transcation object to wrap.
129 txn: The database transaction object to wrap.
128130 name: The name of this transaction, for logging.
129131 database_engine
130132 after_callbacks: A list that callbacks will be appended to
159161 self.after_callbacks = after_callbacks
160162 self.exception_callbacks = exception_callbacks
161163
162 def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
164 def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
163165 """Call the given callback on the main twisted thread after the
164166 transaction has finished. Used to invalidate the caches on the
165167 correct thread.
170172 assert self.after_callbacks is not None
171173 self.after_callbacks.append((callback, args, kwargs))
172174
173 def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
175 def call_on_exception(
176 self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
177 ):
174178 # if self.exception_callbacks is None, that means that whatever constructed the
175179 # LoggingTransaction isn't expecting there to be any callbacks; assert that
176180 # is not the case.
194198 def description(self) -> Any:
195199 return self.txn.description
196200
197 def execute_batch(self, sql, args):
201 def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
198202 if isinstance(self.database_engine, PostgresEngine):
199203 from psycopg2.extras import execute_batch # type: ignore
200204
203207 for val in args:
204208 self.execute(sql, val)
205209
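An illustrative call to the newly typed `execute_batch`: one SQL template applied to many parameter tuples (psycopg2's `execute_batch` on PostgreSQL, a plain loop otherwise). Here `txn` is assumed to be a `LoggingTransaction`; the table, columns, and `?` placeholder style (SQLite) are made up for the example.

```python
# Sketch only: batch-insert several rows through the wrapped transaction.
txn.execute_batch(
    "INSERT INTO example_table (user_id, device_id) VALUES (?, ?)",
    [("@a:example.org", "DEV1"), ("@b:example.org", "DEV2")],
)
```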
206 def execute(self, sql: str, *args: Any):
210 def execute(self, sql: str, *args: Any) -> None:
207211 self._do_execute(self.txn.execute, sql, *args)
208212
209 def executemany(self, sql: str, *args: Any):
213 def executemany(self, sql: str, *args: Any) -> None:
210214 self._do_execute(self.txn.executemany, sql, *args)
211215
212216 def _make_sql_one_line(self, sql: str) -> str:
213217 "Strip newlines out of SQL so that the loggers in the DB are on one line"
214218 return " ".join(line.strip() for line in sql.splitlines() if line.strip())
215219
216 def _do_execute(self, func, sql, *args):
220 def _do_execute(self, func, sql: str, *args: Any) -> None:
217221 sql = self._make_sql_one_line(sql)
218222
219223 # TODO(paul): Maybe use 'info' and 'debug' for values?
239243 sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
240244 sql_query_timer.labels(sql.split()[0]).observe(secs)
241245
242 def close(self):
246 def close(self) -> None:
243247 self.txn.close()
244248
245249
246 class PerformanceCounters(object):
250 class PerformanceCounters:
247251 def __init__(self):
248252 self.current_counters = {}
249253 self.previous_counters = {}
250254
251 def update(self, key, duration_secs):
255 def update(self, key: str, duration_secs: float) -> None:
252256 count, cum_time = self.current_counters.get(key, (0, 0))
253257 count += 1
254258 cum_time += duration_secs
255259 self.current_counters[key] = (count, cum_time)
256260
257 def interval(self, interval_duration_secs, limit=3):
261 def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
258262 counters = []
259263 for name, (count, cum_time) in self.current_counters.items():
260264 prev_count, prev_time = self.previous_counters.get(name, (0, 0))
278282 return top_n_counters
279283
280284
281 class DatabasePool(object):
285 R = TypeVar("R")
286
287
288 class DatabasePool:
282289 """Wraps a single physical database and connection pool.
283290
284291 A single database may be used by multiple data stores.
326333 self._check_safe_to_upsert,
327334 )
328335
329 def is_running(self):
336 def is_running(self) -> bool:
330337 """Is the database pool currently running
331338 """
332339 return self._db_pool.running
333340
334 @defer.inlineCallbacks
335 def _check_safe_to_upsert(self):
341 async def _check_safe_to_upsert(self) -> None:
336342 """
337343 Is it safe to use native UPSERT?
338344
341347
342348 If the background updates have not completed, wait 15 sec and check again.
343349 """
344 updates = yield self.simple_select_list(
350 updates = await self.simple_select_list(
345351 "background_updates",
346352 keyvalues=None,
347353 retcols=["update_name"],
363369 self._check_safe_to_upsert,
364370 )
365371
366 def start_profiling(self):
372 def start_profiling(self) -> None:
367373 self._previous_loop_ts = monotonic_time()
368374
369375 def loop():
387393 self._clock.looping_call(loop, 10000)
388394
389395 def new_transaction(
390 self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
391 ):
396 self,
397 conn: Connection,
398 desc: str,
399 after_callbacks: List[_CallbackListEntry],
400 exception_callbacks: List[_CallbackListEntry],
401 func: "Callable[..., R]",
402 *args: Any,
403 **kwargs: Any
404 ) -> R:
392405 start = monotonic_time()
393406 txn_id = self._TXN_ID
394407
493506 self._txn_perf_counters.update(desc, duration)
494507 sql_txn_timer.labels(desc).observe(duration)
495508
496 @defer.inlineCallbacks
497 def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
509 async def runInteraction(
510 self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
511 ) -> R:
498512 """Starts a transaction on the database and runs a given function
499513
500514 Arguments:
507521 kwargs: named args to pass to `func`
508522
509523 Returns:
510 Deferred: The result of func
524 The result of func
511525 """
512526 after_callbacks = [] # type: List[_CallbackListEntry]
513527 exception_callbacks = [] # type: List[_CallbackListEntry]
516530 logger.warning("Starting db txn '%s' from sentinel context", desc)
517531
518532 try:
519 result = yield self.runWithConnection(
533 result = await self.runWithConnection(
520534 self.new_transaction,
521535 desc,
522536 after_callbacks,
533547 after_callback(*after_args, **after_kwargs)
534548 raise
535549
536 return result
537
538 @defer.inlineCallbacks
539 def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
550 return cast(R, result)
551
552 async def runWithConnection(
553 self, func: "Callable[..., R]", *args: Any, **kwargs: Any
554 ) -> R:
540555 """Wraps the .runWithConnection() method on the underlying db_pool.
541556
542557 Arguments:
547562 kwargs: named args to pass to `func`
548563
549564 Returns:
550 Deferred: The result of func
565 The result of func
551566 """
552567 parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
553568 if not parent_context:
570585
571586 return func(conn, *args, **kwargs)
572587
573 result = yield make_deferred_yieldable(
588 return await make_deferred_yieldable(
574589 self._db_pool.runWithConnection(inner_func, *args, **kwargs)
575590 )
576591
577 return result
578
579592 @staticmethod
580 def cursor_to_dict(cursor):
593 def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
581594 """Converts a SQL cursor into an list of dicts.
582595
583596 Args:
584 cursor : The DBAPI cursor which has executed a query.
597 cursor: The DBAPI cursor which has executed a query.
585598 Returns:
586599 A list of dicts where the key is the column header.
587600 """
589602 results = [dict(zip(col_headers, row)) for row in cursor]
590603 return results
591604
592 def execute(self, desc, decoder, query, *args):
605 @overload
606 async def execute(
607 self, desc: str, decoder: Literal[None], query: str, *args: Any
608 ) -> List[Tuple[Any, ...]]:
609 ...
610
611 @overload
612 async def execute(
613 self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
614 ) -> R:
615 ...
616
617 async def execute(
618 self,
619 desc: str,
620 decoder: Optional[Callable[[Cursor], R]],
621 query: str,
622 *args: Any
623 ) -> R:
593624 """Runs a single query for a result set.
594625
595626 Args:
627 desc: description of the transaction, for logging and metrics
596628 decoder - The function which can resolve the cursor results to
597629 something meaningful.
598630 query - The query string to execute
608640 else:
609641 return txn.fetchall()
610642
611 return self.runInteraction(desc, interaction)
643 return await self.runInteraction(desc, interaction)
612644
613645 # "Simple" SQL API methods that operate on a single table with no JOINs,
614646 # no complex WHERE clauses, just a dict of values for columns.
615647
616 @defer.inlineCallbacks
617 def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
648 async def simple_insert(
649 self,
650 table: str,
651 values: Dict[str, Any],
652 or_ignore: bool = False,
653 desc: str = "simple_insert",
654 ) -> bool:
618655 """Executes an INSERT query on the named table.
619656
620657 Args:
621 table : string giving the table name
622 values : dict of new column names and values for them
623 or_ignore : bool stating whether an exception should be raised
658 table: string giving the table name
659 values: dict of new column names and values for them
660 or_ignore: bool stating whether an exception should be raised
624661 when a conflicting row already exists. If True, False will be
625662 returned by the function instead
626 desc : string giving a description of the transaction
663 desc: description of the transaction, for logging and metrics
627664
628665 Returns:
629 bool: Whether the row was inserted or not. Only useful when
630 `or_ignore` is True
666 Whether the row was inserted or not. Only useful when `or_ignore` is True
631667 """
632668 try:
633 yield self.runInteraction(desc, self.simple_insert_txn, table, values)
669 await self.runInteraction(desc, self.simple_insert_txn, table, values)
634670 except self.engine.module.IntegrityError:
635671 # We have to do or_ignore flag at this layer, since we can't reuse
636672 # a cursor after we receive an error from the db.
640676 return True
641677
642678 @staticmethod
643 def simple_insert_txn(txn, table, values):
679 def simple_insert_txn(
680 txn: LoggingTransaction, table: str, values: Dict[str, Any]
681 ) -> None:
644682 keys, vals = zip(*values.items())
645683
646684 sql = "INSERT INTO %s (%s) VALUES(%s)" % (
651689
652690 txn.execute(sql, vals)
653691
654 def simple_insert_many(self, table, values, desc):
655 return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
692 async def simple_insert_many(
693 self, table: str, values: List[Dict[str, Any]], desc: str
694 ) -> None:
695 """Executes an INSERT query on the named table.
696
697 Args:
698 table: string giving the table name
699 values: a list of dicts, each giving column names and their values
700 desc: description of the transaction, for logging and metrics
701 """
702 await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
656703
657704 @staticmethod
658 def simple_insert_many_txn(txn, table, values):
705 def simple_insert_many_txn(
706 txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
707 ) -> None:
708 """Executes an INSERT query on the named table.
709
710 Args:
711 txn: The transaction to use.
712 table: string giving the table name
713 values: a list of dicts, each giving column names and their values
714 """
659715 if not values:
660716 return
661717
683739
684740 txn.executemany(sql, vals)
685741
686 @defer.inlineCallbacks
687 def simple_upsert(
688 self,
689 table,
690 keyvalues,
691 values,
692 insertion_values={},
693 desc="simple_upsert",
694 lock=True,
695 ):
742 async def simple_upsert(
743 self,
744 table: str,
745 keyvalues: Dict[str, Any],
746 values: Dict[str, Any],
747 insertion_values: Dict[str, Any] = {},
748 desc: str = "simple_upsert",
749 lock: bool = True,
750 ) -> Optional[bool]:
696751 """
697752
698753 `lock` should generally be set to True (the default), but can be set
706761 this table.
707762
708763 Args:
709 table (str): The table to upsert into
710 keyvalues (dict): The unique key columns and their new values
711 values (dict): The nonunique columns and their new values
712 insertion_values (dict): additional key/values to use only when
713 inserting
714 lock (bool): True to lock the table when doing the upsert.
764 table: The table to upsert into
765 keyvalues: The unique key columns and their new values
766 values: The nonunique columns and their new values
767 insertion_values: additional key/values to use only when inserting
768 desc: description of the transaction, for logging and metrics
769 lock: True to lock the table when doing the upsert.
715770 Returns:
716 Deferred(None or bool): Native upserts always return None. Emulated
717 upserts return True if a new entry was created, False if an existing
718 one was updated.
771 Native upserts always return None. Emulated upserts return True if a
772 new entry was created, False if an existing one was updated.
719773 """
720774 attempts = 0
721775 while True:
722776 try:
723 result = yield self.runInteraction(
777 return await self.runInteraction(
724778 desc,
725779 self.simple_upsert_txn,
726780 table,
729783 insertion_values,
730784 lock=lock,
731785 )
732 return result
733786 except self.engine.module.IntegrityError as e:
734787 attempts += 1
735788 if attempts >= 5:
743796 )
744797
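A minimal usage sketch of the retrying upsert above (table and column names are assumptions for illustration):

    # Native upserts return None; the emulated path returns True when a new
    # row was created, so only treat a boolean result as meaningful.
    created = await self.db_pool.simple_upsert(
        table="user_directory",  # placeholder table
        keyvalues={"user_id": "@alice:example.com"},
        values={"display_name": "Alice"},
        insertion_values={"created_ts": 0},  # applied only on insert
        desc="upsert_user_directory",
    )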
745798 def simple_upsert_txn(
746 self, txn, table, keyvalues, values, insertion_values={}, lock=True
747 ):
799 self,
800 txn: LoggingTransaction,
801 table: str,
802 keyvalues: Dict[str, Any],
803 values: Dict[str, Any],
804 insertion_values: Dict[str, Any] = {},
805 lock: bool = True,
806 ) -> Optional[bool]:
748807 """
749808 Pick the UPSERT method which works best on the platform. Either the
750809 native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
751810
752811 Args:
753812 txn: The transaction to use.
754 table (str): The table to upsert into
755 keyvalues (dict): The unique key tables and their new values
756 values (dict): The nonunique columns and their new values
757 insertion_values (dict): additional key/values to use only when
758 inserting
759 lock (bool): True to lock the table when doing the upsert.
813 table: The table to upsert into
814 keyvalues: The unique key tables and their new values
815 values: The nonunique columns and their new values
816 insertion_values: additional key/values to use only when inserting
817 lock: True to lock the table when doing the upsert.
760818 Returns:
761 None or bool: Native upserts always return None. Emulated
762 upserts return True if a new entry was created, False if an existing
763 one was updated.
819 Native upserts always return None. Emulated upserts return True if a
820 new entry was created, False if an existing one was updated.
764821 """
765822 if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
766 return self.simple_upsert_txn_native_upsert(
823 self.simple_upsert_txn_native_upsert(
767824 txn, table, keyvalues, values, insertion_values=insertion_values
768825 )
826 return None
769827 else:
770828 return self.simple_upsert_txn_emulated(
771829 txn,
777835 )
778836
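Roughly, the native path selected by the dispatch above emits a single INSERT ... ON CONFLICT statement; a sketch of the SQL shape on PostgreSQL 9.5+ (the column lists are illustrative, not the literal string the method builds):

    # Approximate shape of the statement produced by
    # simple_upsert_txn_native_upsert for the placeholder columns above.
    sql = (
        "INSERT INTO user_directory (user_id, display_name) "
        "VALUES (?, ?) "
        "ON CONFLICT (user_id) DO UPDATE SET display_name = EXCLUDED.display_name"
    )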
779837 def simple_upsert_txn_emulated(
780 self, txn, table, keyvalues, values, insertion_values={}, lock=True
781 ):
782 """
783 Args:
784 table (str): The table to upsert into
785 keyvalues (dict): The unique key tables and their new values
786 values (dict): The nonunique columns and their new values
787 insertion_values (dict): additional key/values to use only when
788 inserting
789 lock (bool): True to lock the table when doing the upsert.
838 self,
839 txn: LoggingTransaction,
840 table: str,
841 keyvalues: Dict[str, Any],
842 values: Dict[str, Any],
843 insertion_values: Dict[str, Any] = {},
844 lock: bool = True,
845 ) -> bool:
846 """
847 Args:
848 table: The table to upsert into
849 keyvalues: The unique key tables and their new values
850 values: The nonunique columns and their new values
851 insertion_values: additional key/values to use only when inserting
852 lock: True to lock the table when doing the upsert.
790853 Returns:
791 bool: Return True if a new entry was created, False if an existing
854 Returns True if a new entry was created, False if an existing
792855 one was updated.
793856 """
794857 # We need to lock the table :(, unless we're *really* careful
846909 return True
847910
848911 def simple_upsert_txn_native_upsert(
849 self, txn, table, keyvalues, values, insertion_values={}
850 ):
912 self,
913 txn: LoggingTransaction,
914 table: str,
915 keyvalues: Dict[str, Any],
916 values: Dict[str, Any],
917 insertion_values: Dict[str, Any] = {},
918 ) -> None:
851919 """
852920 Use the native UPSERT functionality in recent PostgreSQL versions.
853921
854922 Args:
855 table (str): The table to upsert into
856 keyvalues (dict): The unique key tables and their new values
857 values (dict): The nonunique columns and their new values
858 insertion_values (dict): additional key/values to use only when
859 inserting
860 Returns:
861 None
923 table: The table to upsert into
924 keyvalues: The unique key tables and their new values
925 values: The nonunique columns and their new values
926 insertion_values: additional key/values to use only when inserting
862927 """
863928 allvalues = {} # type: Dict[str, Any]
864929 allvalues.update(keyvalues)
9881053
9891054 return txn.execute_batch(sql, args)
9901055
991 def simple_select_one(
992 self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
993 ):
1056 @overload
1057 async def simple_select_one(
1058 self,
1059 table: str,
1060 keyvalues: Dict[str, Any],
1061 retcols: Iterable[str],
1062 allow_none: Literal[False] = False,
1063 desc: str = "simple_select_one",
1064 ) -> Dict[str, Any]:
1065 ...
1066
1067 @overload
1068 async def simple_select_one(
1069 self,
1070 table: str,
1071 keyvalues: Dict[str, Any],
1072 retcols: Iterable[str],
1073 allow_none: Literal[True] = True,
1074 desc: str = "simple_select_one",
1075 ) -> Optional[Dict[str, Any]]:
1076 ...
1077
1078 async def simple_select_one(
1079 self,
1080 table: str,
1081 keyvalues: Dict[str, Any],
1082 retcols: Iterable[str],
1083 allow_none: bool = False,
1084 desc: str = "simple_select_one",
1085 ) -> Optional[Dict[str, Any]]:
9941086 """Executes a SELECT query on the named table, which is expected to
9951087 return a single row, returning multiple columns from it.
9961088
9971089 Args:
998 table : string giving the table name
999 keyvalues : dict of column names and values to select the row with
1000 retcols : list of strings giving the names of the columns to return
1001
1002 allow_none : If true, return None instead of failing if the SELECT
1003 statement returns no rows
1004 """
1005 return self.runInteraction(
1090 table: string giving the table name
1091 keyvalues: dict of column names and values to select the row with
1092 retcols: list of strings giving the names of the columns to return
1093 allow_none: If true, return None instead of failing if the SELECT
1094 statement returns no rows
1095 desc: description of the transaction, for logging and metrics
1096 """
1097 return await self.runInteraction(
10061098 desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
10071099 )
10081100
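The paired @overload declarations above exist so that type checkers can narrow the return type from the `allow_none` literal; a sketch of what that buys callers (table and column names assumed):

    # allow_none=False (the default): a missing row raises StoreError(404),
    # so the result is typed as a plain Dict[str, Any].
    row = await self.db_pool.simple_select_one(
        table="profiles", keyvalues={"user_id": "alice"}, retcols=("displayname",)
    )

    # allow_none=True: the result is Optional and must be checked.
    maybe_row = await self.db_pool.simple_select_one(
        table="profiles",
        keyvalues={"user_id": "bob"},
        retcols=("displayname",),
        allow_none=True,
    )
    if maybe_row is None:
        pass  # no matching row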
1009 def simple_select_one_onecol(
1010 self,
1011 table,
1012 keyvalues,
1013 retcol,
1014 allow_none=False,
1015 desc="simple_select_one_onecol",
1016 ):
1101 @overload
1102 async def simple_select_one_onecol(
1103 self,
1104 table: str,
1105 keyvalues: Dict[str, Any],
1106 retcol: str,
1107 allow_none: Literal[False] = False,
1108 desc: str = "simple_select_one_onecol",
1109 ) -> Any:
1110 ...
1111
1112 @overload
1113 async def simple_select_one_onecol(
1114 self,
1115 table: str,
1116 keyvalues: Dict[str, Any],
1117 retcol: str,
1118 allow_none: Literal[True] = True,
1119 desc: str = "simple_select_one_onecol",
1120 ) -> Optional[Any]:
1121 ...
1122
1123 async def simple_select_one_onecol(
1124 self,
1125 table: str,
1126 keyvalues: Dict[str, Any],
1127 retcol: str,
1128 allow_none: bool = False,
1129 desc: str = "simple_select_one_onecol",
1130 ) -> Optional[Any]:
10171131 """Executes a SELECT query on the named table, which is expected to
10181132 return a single row, returning a single column from it.
10191133
10201134 Args:
1021 table : string giving the table name
1022 keyvalues : dict of column names and values to select the row with
1023 retcol : string giving the name of the column to return
1024 """
1025 return self.runInteraction(
1135 table: string giving the table name
1136 keyvalues: dict of column names and values to select the row with
1137 retcol: string giving the name of the column to return
1138 allow_none: If true, return None instead of failing if the SELECT
1139 statement returns no rows
1140 desc: description of the transaction, for logging and metrics
1141 """
1142 return await self.runInteraction(
10261143 desc,
10271144 self.simple_select_one_onecol_txn,
10281145 table,
10311148 allow_none=allow_none,
10321149 )
10331150
1151 @overload
10341152 @classmethod
10351153 def simple_select_one_onecol_txn(
1036 cls, txn, table, keyvalues, retcol, allow_none=False
1037 ):
1154 cls,
1155 txn: LoggingTransaction,
1156 table: str,
1157 keyvalues: Dict[str, Any],
1158 retcol: str,
1159 allow_none: Literal[False] = False,
1160 ) -> Any:
1161 ...
1162
1163 @overload
1164 @classmethod
1165 def simple_select_one_onecol_txn(
1166 cls,
1167 txn: LoggingTransaction,
1168 table: str,
1169 keyvalues: Dict[str, Any],
1170 retcol: str,
1171 allow_none: Literal[True] = True,
1172 ) -> Optional[Any]:
1173 ...
1174
1175 @classmethod
1176 def simple_select_one_onecol_txn(
1177 cls,
1178 txn: LoggingTransaction,
1179 table: str,
1180 keyvalues: Dict[str, Any],
1181 retcol: str,
1182 allow_none: bool = False,
1183 ) -> Optional[Any]:
10381184 ret = cls.simple_select_onecol_txn(
10391185 txn, table=table, keyvalues=keyvalues, retcol=retcol
10401186 )
10481194 raise StoreError(404, "No row found")
10491195
10501196 @staticmethod
1051 def simple_select_onecol_txn(txn, table, keyvalues, retcol):
1197 def simple_select_onecol_txn(
1198 txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
1199 ) -> List[Any]:
10521200 sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
10531201
10541202 if keyvalues:
10591207
10601208 return [r[0] for r in txn]
10611209
1062 def simple_select_onecol(
1063 self, table, keyvalues, retcol, desc="simple_select_onecol"
1064 ):
1210 async def simple_select_onecol(
1211 self,
1212 table: str,
1213 keyvalues: Optional[Dict[str, Any]],
1214 retcol: str,
1215 desc: str = "simple_select_onecol",
1216 ) -> List[Any]:
10651217 """Executes a SELECT query on the named table, which returns a list
10661218 comprising of the values of the named column from the selected rows.
10671219
10681220 Args:
1069 table (str): table name
1070 keyvalues (dict|None): column names and values to select the rows with
1071 retcol (str): column whos value we wish to retrieve.
1221 table: table name
1222 keyvalues: column names and values to select the rows with
1223 retcol: column whose value we wish to retrieve.
1224 desc: description of the transaction, for logging and metrics
10721225
10731226 Returns:
1074 Deferred: Results in a list
1075 """
1076 return self.runInteraction(
1227 A list of the values of the named column from the selected rows.
1228 """
1229 return await self.runInteraction(
10771230 desc, self.simple_select_onecol_txn, table, keyvalues, retcol
10781231 )
10791232
1080 def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
1233 async def simple_select_list(
1234 self,
1235 table: str,
1236 keyvalues: Optional[Dict[str, Any]],
1237 retcols: Iterable[str],
1238 desc: str = "simple_select_list",
1239 ) -> List[Dict[str, Any]]:
10811240 """Executes a SELECT query on the named table, which may return zero or
10821241 more rows, returning the result as a list of dicts.
10831242
10841243 Args:
1085 table (str): the table name
1086 keyvalues (dict[str, Any] | None):
1244 table: the table name
1245 keyvalues:
10871246 column names and values to select the rows with, or None to not
10881247 apply a WHERE clause.
1089 retcols (iterable[str]): the names of the columns to return
1248 retcols: the names of the columns to return
1249 desc: description of the transaction, for logging and metrics
1250
10901251 Returns:
1091 defer.Deferred: resolves to list[dict[str, Any]]
1092 """
1093 return self.runInteraction(
1252 A list of dictionaries.
1253 """
1254 return await self.runInteraction(
10941255 desc, self.simple_select_list_txn, table, keyvalues, retcols
10951256 )
10961257
10971258 @classmethod
1098 def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
1259 def simple_select_list_txn(
1260 cls,
1261 txn: LoggingTransaction,
1262 table: str,
1263 keyvalues: Optional[Dict[str, Any]],
1264 retcols: Iterable[str],
1265 ) -> List[Dict[str, Any]]:
10991266 """Executes a SELECT query on the named table, which may return zero or
11001267 more rows, returning the result as a list of dicts.
11011268
11021269 Args:
1103 txn : Transaction object
1104 table (str): the table name
1105 keyvalues (dict[str, T] | None):
1270 txn: Transaction object
1271 table: the table name
1272 keyvalues:
11061273 column names and values to select the rows with, or None to not
11071274 apply a WHERE clause.
1108 retcols (iterable[str]): the names of the columns to return
1275 retcols: the names of the columns to return
11091276 """
11101277 if keyvalues:
11111278 sql = "SELECT %s FROM %s WHERE %s" % (
11201287
11211288 return cls.cursor_to_dict(txn)
11221289
1123 @defer.inlineCallbacks
1124 def simple_select_many_batch(
1125 self,
1126 table,
1127 column,
1128 iterable,
1129 retcols,
1130 keyvalues={},
1131 desc="simple_select_many_batch",
1132 batch_size=100,
1133 ):
1290 async def simple_select_many_batch(
1291 self,
1292 table: str,
1293 column: str,
1294 iterable: Iterable[Any],
1295 retcols: Iterable[str],
1296 keyvalues: Dict[str, Any] = {},
1297 desc: str = "simple_select_many_batch",
1298 batch_size: int = 100,
1299 ) -> List[Any]:
11341300 """Executes a SELECT query on the named table, which may return zero or
11351301 more rows, returning the result as a list of dicts.
11361302
1137 Filters rows by if value of `column` is in `iterable`.
1138
1139 Args:
1140 table : string giving the table name
1141 column : column name to test for inclusion against `iterable`
1142 iterable : list
1143 keyvalues : dict of column names and values to select the rows with
1144 retcols : list of strings giving the names of the columns to return
1303 Filters rows by whether the value of `column` is in `iterable`.
1304
1305 Args:
1306 table: string giving the table name
1307 column: column name to test for inclusion against `iterable`
1308 iterable: list
1309 retcols: list of strings giving the names of the columns to return
1310 keyvalues: dict of column names and values to select the rows with
1311 desc: description of the transaction, for logging and metrics
1312 batch_size: the number of rows for each select query
11451313 """
11461314 results = [] # type: List[Dict[str, Any]]
11471315
11551323 it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
11561324 ]
11571325 for chunk in chunks:
1158 rows = yield self.runInteraction(
1326 rows = await self.runInteraction(
11591327 desc,
11601328 self.simple_select_many_txn,
11611329 table,
11701338 return results
11711339
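A sketch of the batching behaviour documented above: the iterable is chunked so that very large IN (...) lists never reach the database in a single statement (names are placeholders):

    # With batch_size=100, 250 event IDs become three SELECTs whose
    # results are concatenated before being returned.
    rows = await self.db_pool.simple_select_many_batch(
        table="events",      # placeholder table
        column="event_id",
        iterable=event_ids,  # e.g. a list of 250 IDs
        retcols=("event_id", "type"),
        batch_size=100,
    )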
11721340 @classmethod
1173 def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
1341 def simple_select_many_txn(
1342 cls,
1343 txn: LoggingTransaction,
1344 table: str,
1345 column: str,
1346 iterable: Iterable[Any],
1347 keyvalues: Dict[str, Any],
1348 retcols: Iterable[str],
1349 ) -> List[Dict[str, Any]]:
11741350 """Executes a SELECT query on the named table, which may return zero or
11751351 more rows, returning the result as a list of dicts.
11761352
1177 Filters rows by if value of `column` is in `iterable`.
1178
1179 Args:
1180 txn : Transaction object
1181 table : string giving the table name
1182 column : column name to test for inclusion against `iterable`
1183 iterable : list
1184 keyvalues : dict of column names and values to select the rows with
1185 retcols : list of strings giving the names of the columns to return
1353 Filters rows by whether the value of `column` is in `iterable`.
1354
1355 Args:
1356 txn: Transaction object
1357 table: string giving the table name
1358 column: column name to test for inclusion against `iterable`
1359 iterable: list
1360 keyvalues: dict of column names and values to select the rows with
1361 retcols: list of strings giving the names of the columns to return
11861362 """
11871363 if not iterable:
11881364 return []
12031379 txn.execute(sql, values)
12041380 return cls.cursor_to_dict(txn)
12051381
1206 def simple_update(self, table, keyvalues, updatevalues, desc):
1207 return self.runInteraction(
1382 async def simple_update(
1383 self,
1384 table: str,
1385 keyvalues: Dict[str, Any],
1386 updatevalues: Dict[str, Any],
1387 desc: str,
1388 ) -> int:
1389 return await self.runInteraction(
12081390 desc, self.simple_update_txn, table, keyvalues, updatevalues
12091391 )
12101392
12111393 @staticmethod
1212 def simple_update_txn(txn, table, keyvalues, updatevalues):
1394 def simple_update_txn(
1395 txn: LoggingTransaction,
1396 table: str,
1397 keyvalues: Dict[str, Any],
1398 updatevalues: Dict[str, Any],
1399 ) -> int:
12131400 if keyvalues:
12141401 where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
12151402 else:
12251412
12261413 return txn.rowcount
12271414
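For illustration, the int now returned by `simple_update` above is the affected row count, which callers can use as a cheap existence check (a sketch; names assumed):

    updated = await self.db_pool.simple_update(
        table="pushers",  # placeholder table
        keyvalues={"user_name": "@alice:example.com"},
        updatevalues={"last_stream_ordering": 1234},
        desc="update_pusher_position",
    )
    if updated == 0:
        pass  # no matching rows; nothing was changed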
1228 def simple_update_one(
1229 self, table, keyvalues, updatevalues, desc="simple_update_one"
1230 ):
1415 async def simple_update_one(
1416 self,
1417 table: str,
1418 keyvalues: Dict[str, Any],
1419 updatevalues: Dict[str, Any],
1420 desc: str = "simple_update_one",
1421 ) -> None:
12311422 """Executes an UPDATE query on the named table, setting new values for
12321423 columns in a row matching the key values.
12331424
12341425 Args:
1235 table : string giving the table name
1236 keyvalues : dict of column names and values to select the row with
1237 updatevalues : dict giving column names and values to update
1238 retcols : optional list of column names to return
1239
1240 If present, retcols gives a list of column names on which to perform
1241 a SELECT statement *before* performing the UPDATE statement. The values
1242 of these will be returned in a dict.
1243
1244 These are performed within the same transaction, allowing an atomic
1245 get-and-set. This can be used to implement compare-and-set by putting
1246 the update column in the 'keyvalues' dict as well.
1247 """
1248 return self.runInteraction(
1426 table: string giving the table name
1427 keyvalues: dict of column names and values to select the row with
1428 updatevalues: dict giving column names and values to update
1429 desc: description of the transaction, for logging and metrics
1430 """
1431 await self.runInteraction(
12491432 desc, self.simple_update_one_txn, table, keyvalues, updatevalues
12501433 )
12511434
12521435 @classmethod
1253 def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
1436 def simple_update_one_txn(
1437 cls,
1438 txn: LoggingTransaction,
1439 table: str,
1440 keyvalues: Dict[str, Any],
1441 updatevalues: Dict[str, Any],
1442 ) -> None:
12541443 rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
12551444
12561445 if rowcount == 0:
12581447 if rowcount > 1:
12591448 raise StoreError(500, "More than one row matched (%s)" % (table,))
12601449
1450 # Ideally we could use the overload decorator here to specify that the
1451 # return type is only optional if allow_none is True, but this does not work
1452 # when you call a static method from an instance.
1453 # See https://github.com/python/mypy/issues/7781
12611454 @staticmethod
1262 def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
1455 def simple_select_one_txn(
1456 txn: LoggingTransaction,
1457 table: str,
1458 keyvalues: Dict[str, Any],
1459 retcols: Iterable[str],
1460 allow_none: bool = False,
1461 ) -> Optional[Dict[str, Any]]:
12631462 select_sql = "SELECT %s FROM %s WHERE %s" % (
12641463 ", ".join(retcols),
12651464 table,
12781477
12791478 return dict(zip(retcols, row))
12801479
1281 def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
1480 async def simple_delete_one(
1481 self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
1482 ) -> None:
12821483 """Executes a DELETE query on the named table, expecting to delete a
12831484 single row.
12841485
12851486 Args:
1286 table : string giving the table name
1287 keyvalues : dict of column names and values to select the row with
1288 """
1289 return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
1487 table: string giving the table name
1488 keyvalues: dict of column names and values to select the row with
1489 desc: description of the transaction, for logging and metrics
1490 """
1491 await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
12901492
12911493 @staticmethod
1292 def simple_delete_one_txn(txn, table, keyvalues):
1494 def simple_delete_one_txn(
1495 txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
1496 ) -> None:
12931497 """Executes a DELETE query on the named table, expecting to delete a
12941498 single row.
12951499
12961500 Args:
1297 table : string giving the table name
1298 keyvalues : dict of column names and values to select the row with
1501 table: string giving the table name
1502 keyvalues: dict of column names and values to select the row with
12991503 """
13001504 sql = "DELETE FROM %s WHERE %s" % (
13011505 table,
13081512 if txn.rowcount > 1:
13091513 raise StoreError(500, "More than one row matched (%s)" % (table,))
13101514
1311 def simple_delete(self, table, keyvalues, desc):
1312 return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
1515 async def simple_delete(
1516 self, table: str, keyvalues: Dict[str, Any], desc: str
1517 ) -> int:
1518 """Executes a DELETE query on the named table.
1519
1520 Filters rows by the key-value pairs.
1521
1522 Args:
1523 table: string giving the table name
1524 keyvalues: dict of column names and values to select the row with
1525 desc: description of the transaction, for logging and metrics
1526
1527 Returns:
1528 The number of deleted rows.
1529 """
1530 return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
13131531
13141532 @staticmethod
1315 def simple_delete_txn(txn, table, keyvalues):
1533 def simple_delete_txn(
1534 txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
1535 ) -> int:
1536 """Executes a DELETE query on the named table.
1537
1538 Filters rows by the key-value pairs.
1539
1540 Args:
1541 table: string giving the table name
1542 keyvalues: dict of column names and values to select the row with
1543
1544 Returns:
1545 The number of deleted rows.
1546 """
13161547 sql = "DELETE FROM %s WHERE %s" % (
13171548 table,
13181549 " AND ".join("%s = ?" % (k,) for k in keyvalues),
13211552 txn.execute(sql, list(keyvalues.values()))
13221553 return txn.rowcount
13231554
1324 def simple_delete_many(self, table, column, iterable, keyvalues, desc):
1325 return self.runInteraction(
1555 async def simple_delete_many(
1556 self,
1557 table: str,
1558 column: str,
1559 iterable: Iterable[Any],
1560 keyvalues: Dict[str, Any],
1561 desc: str,
1562 ) -> int:
1563 """Executes a DELETE query on the named table.
1564
1565 Filters rows by whether the value of `column` is in `iterable`.
1566
1567 Args:
1568 table: string giving the table name
1569 column: column name to test for inclusion against `iterable`
1570 iterable: list
1571 keyvalues: dict of column names and values to select the rows with
1572 desc: description of the transaction, for logging and metrics
1573
1574 Returns:
1575 The number of deleted rows.
1576 """
1577 return await self.runInteraction(
13261578 desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
13271579 )
13281580
13291581 @staticmethod
1330 def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
1582 def simple_delete_many_txn(
1583 txn: LoggingTransaction,
1584 table: str,
1585 column: str,
1586 iterable: Iterable[Any],
1587 keyvalues: Dict[str, Any],
1588 ) -> int:
13311589 """Executes a DELETE query on the named table.
13321590
13341592 Filters rows by whether the value of `column` is in `iterable`.
13341592
13351593 Args:
1336 txn : Transaction object
1337 table : string giving the table name
1338 column : column name to test for inclusion against `iterable`
1339 iterable : list
1340 keyvalues : dict of column names and values to select the rows with
1594 txn: Transaction object
1595 table: string giving the table name
1596 column: column name to test for inclusion against `iterable`
1597 iterable: list
1598 keyvalues: dict of column names and values to select the rows with
13411599
13421600 Returns:
1343 int: Number rows deleted
1601 The number of deleted rows.
13441602 """
13451603 if not iterable:
13461604 return 0
13611619 return txn.rowcount
13621620
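A short sketch of the delete-many contract above; note that an empty iterable short-circuits to 0 without touching the database (names are illustrative):

    deleted = await self.db_pool.simple_delete_many(
        table="device_inbox",       # placeholder table
        column="device_id",
        iterable=stale_device_ids,  # may be empty
        keyvalues={"user_id": "@alice:example.com"},
        desc="prune_device_inbox",
    )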
13631621 def get_cache_dict(
1364 self, db_conn, table, entity_column, stream_column, max_value, limit=100000
1365 ):
1622 self,
1623 db_conn: Connection,
1624 table: str,
1625 entity_column: str,
1626 stream_column: str,
1627 max_value: int,
1628 limit: int = 100000,
1629 ) -> Tuple[Dict[Any, int], int]:
13661630 # Fetch a mapping of room_id -> max stream position for "recent" rooms.
13671631 # It doesn't really matter how many we get, the StreamChangeCache will
13681632 # do the right thing to ensure it respects the max size of cache.
13931657
13941658 return cache, min_val
13951659
1396 def simple_select_list_paginate(
1397 self,
1398 table,
1399 orderby,
1400 start,
1401 limit,
1402 retcols,
1403 filters=None,
1404 keyvalues=None,
1405 order_direction="ASC",
1406 desc="simple_select_list_paginate",
1407 ):
1660 @classmethod
1661 def simple_select_list_paginate_txn(
1662 cls,
1663 txn: LoggingTransaction,
1664 table: str,
1665 orderby: str,
1666 start: int,
1667 limit: int,
1668 retcols: Iterable[str],
1669 filters: Optional[Dict[str, Any]] = None,
1670 keyvalues: Optional[Dict[str, Any]] = None,
1671 order_direction: str = "ASC",
1672 ) -> List[Dict[str, Any]]:
14081673 """
14091674 Executes a SELECT query on the named table with start and limit
14101675 on row numbers, which may return between zero and `limit` rows,
14111676 returning the result as a list of dicts.
14121677
1413 Args:
1414 table (str): the table name
1415 filters (dict[str, T] | None):
1416 column names and values to filter the rows with, or None to not
1417 apply a WHERE ? LIKE ? clause.
1418 keyvalues (dict[str, T] | None):
1419 column names and values to select the rows with, or None to not
1420 apply a WHERE clause.
1421 orderby (str): Column to order the results by.
1422 start (int): Index to begin the query at.
1423 limit (int): Number of results to return.
1424 retcols (iterable[str]): the names of the columns to return
1425 order_direction (str): Whether the results should be ordered "ASC" or "DESC".
1426 Returns:
1427 defer.Deferred: resolves to list[dict[str, Any]]
1428 """
1429 return self.runInteraction(
1430 desc,
1431 self.simple_select_list_paginate_txn,
1432 table,
1433 orderby,
1434 start,
1435 limit,
1436 retcols,
1437 filters=filters,
1438 keyvalues=keyvalues,
1439 order_direction=order_direction,
1440 )
1441
1442 @classmethod
1443 def simple_select_list_paginate_txn(
1444 cls,
1445 txn,
1446 table,
1447 orderby,
1448 start,
1449 limit,
1450 retcols,
1451 filters=None,
1452 keyvalues=None,
1453 order_direction="ASC",
1454 ):
1455 """
1456 Executes a SELECT query on the named table with start and limit,
1457 of row numbers, which may return zero or number of rows from start to limit,
1458 returning the result as a list of dicts.
1459
14601678 Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
14611679 select attributes with exact matches. All constraints are joined together
14621680 using 'AND'.
14631681
14641682 Args:
1465 txn : Transaction object
1466 table (str): the table name
1467 orderby (str): Column to order the results by.
1468 start (int): Index to begin the query at.
1469 limit (int): Number of results to return.
1470 retcols (iterable[str]): the names of the columns to return
1471 filters (dict[str, T] | None):
1683 txn: Transaction object
1684 table: the table name
1685 orderby: Column to order the results by.
1686 start: Index to begin the query at.
1687 limit: Number of results to return.
1688 retcols: the names of the columns to return
1689 filters:
14721690 column names and values to filter the rows with, or None to not
14731691 apply a WHERE ? LIKE ? clause.
1474 keyvalues (dict[str, T] | None):
1692 keyvalues:
14751693 column names and values to select the rows with, or None to not
14761694 apply a WHERE clause.
1477 order_direction (str): Whether the results should be ordered "ASC" or "DESC".
1695 order_direction: Whether the results should be ordered "ASC" or "DESC".
1696
14781697 Returns:
1479 defer.Deferred: resolves to list[dict[str, Any]]
1698 The result as a list of dictionaries.
14801699 """
14811700 if order_direction not in ["ASC", "DESC"]:
14821701 raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
15021721
15031722 return cls.cursor_to_dict(txn)
15041723
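A hedged usage sketch of the pagination helper above, combining a wildcard `filters` entry with an exact-match `keyvalues` entry, as run inside a transaction callback (all names are placeholders):

    def get_page_txn(txn):
        return self.db_pool.simple_select_list_paginate_txn(
            txn,
            table="users",                 # placeholder table
            orderby="name",
            start=0,
            limit=50,
            retcols=("name", "admin"),
            filters={"name": "%ali%"},     # matched with LIKE
            keyvalues={"deactivated": 0},  # matched with =
            order_direction="ASC",
        )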
1505 def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
1724 async def simple_search_list(
1725 self,
1726 table: str,
1727 term: Optional[str],
1728 col: str,
1729 retcols: Iterable[str],
1730 desc: str = "simple_search_list",
1731 ) -> Optional[List[Dict[str, Any]]]:
15061732 """Executes a SELECT query on the named table, which may return zero or
15071733 more rows, returning the result as a list of dicts.
15081734
15091735 Args:
1510 table (str): the table name
1511 term (str | None):
1512 term for searching the table matched to a column.
1513 col (str): column to query term should be matched to
1514 retcols (iterable[str]): the names of the columns to return
1736 table: the table name
1737 term: term for searching the table matched to a column.
1738 col: column to query term should be matched to
1739 retcols: the names of the columns to return
1740
15151741 Returns:
1516 defer.Deferred: resolves to list[dict[str, Any]] or None
1517 """
1518
1519 return self.runInteraction(
1742 A list of dictionaries, or None if no search term is given.
1743 """
1744
1745 return await self.runInteraction(
15201746 desc, self.simple_search_list_txn, table, term, col, retcols
15211747 )
15221748
15231749 @classmethod
1524 def simple_search_list_txn(cls, txn, table, term, col, retcols):
1750 def simple_search_list_txn(
1751 cls,
1752 txn: LoggingTransaction,
1753 table: str,
1754 term: Optional[str],
1755 col: str,
1756 retcols: Iterable[str],
1757 ) -> Optional[List[Dict[str, Any]]]:
15251758 """Executes a SELECT query on the named table, which may return zero or
15261759 more rows, returning the result as a list of dicts.
15271760
15281761 Args:
1529 txn : Transaction object
1530 table (str): the table name
1531 term (str | None):
1532 term for searching the table matched to a column.
1533 col (str): column to query term should be matched to
1534 retcols (iterable[str]): the names of the columns to return
1762 txn: Transaction object
1763 table: the table name
1764 term: term for searching the table matched to a column.
1765 col: column to query term should be matched to
1766 retcols: the names of the columns to return
1767
15351768 Returns:
1536 defer.Deferred: resolves to list[dict[str, Any]] or None
1769 None if no term is given, otherwise a list of dictionaries.
15371770 """
15381771 if term:
15391772 sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
15401773 termvalues = ["%%" + term + "%%"]
15411774 txn.execute(sql, termvalues)
15421775 else:
1543 return 0
1776 return None
15441777
15451778 return cls.cursor_to_dict(txn)
15461779
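For clarity, the term above is wrapped in SQL wildcards, so the search is a substring match; a sketch (placeholder names):

    # term "ali" becomes the LIKE pattern "%ali%"; a falsy term returns
    # None rather than running an unfiltered query.
    users = await self.db_pool.simple_search_list(
        table="users", term="ali", col="name", retcols=("name",)
    )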
15471780
15481781 def make_in_list_sql_clause(
1549 database_engine, column: str, iterable: Iterable
1782 database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
15501783 ) -> Tuple[str, list]:
15511784 """Returns an SQL clause that checks the given column is in the iterable.
15521785
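To make the helper above concrete: on PostgreSQL it can bind a single array-valued parameter, while other engines fall back to one placeholder per element. A hedged sketch of the expected shapes (the exact strings are an assumption from the function's contract):

    clause, args = make_in_list_sql_clause(
        database_engine, "event_id", ["$a", "$b"]
    )
    # PostgreSQL-style engines:  clause == "event_id = ANY(?)",  args == [["$a", "$b"]]
    # Other engines:             clause == "event_id IN (?,?)",  args == ["$a", "$b"]
    txn.execute("SELECT * FROM events WHERE " + clause, args)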
2323 logger = logging.getLogger(__name__)
2424
2525
26 class Databases(object):
26 class Databases:
2727 """The various databases.
2828
2929 These are low level interfaces to physical databases.
4646 engine = create_engine(database_config.config)
4747
4848 with make_conn(database_config, engine) as db_conn:
49 logger.info("Preparing database %r...", db_name)
49 logger.info("[database config %r]: Checking database server", db_name)
50 engine.check_database(db_conn)
5051
51 engine.check_database(db_conn)
52 logger.info(
53 "[database config %r]: Preparing for databases %r",
54 db_name,
55 database_config.databases,
56 )
5257 prepare_database(
5358 db_conn, engine, hs.config, databases=database_config.databases,
5459 )
5661 database = DatabasePool(hs, database_config, engine)
5762
5863 if "main" in database_config.databases:
59 logger.info("Starting 'main' data store")
64 logger.info(
65 "[database config %r]: Starting 'main' database", db_name
66 )
6067
6168 # Sanity check we don't try and configure the main store on
6269 # multiple databases.
7178 persist_events = PersistEventsStore(hs, database, main)
7279
7380 if "state" in database_config.databases:
74 logger.info("Starting 'state' data store")
81 logger.info(
82 "[database config %r]: Starting 'state' database", db_name
83 )
7584
7685 # Sanity check we don't try and configure the state store on
7786 # multiple databases.
8493
8594 self.databases.append(database)
8695
87 logger.info("Database %r prepared", db_name)
96 logger.info("[database config %r]: prepared", db_name)
97
98 # Closing the context manager doesn't close the connection.
99 # psycopg will close the connection when the object gets GCed, but *only*
100 # if the PID is the same as when the connection was opened [1], and
101 # it may not be if we fork in the meantime.
102 #
103 # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
104
105 db_conn.close()
88106
89107 # Sanity check that we have actually configured all the required stores.
90108 if not main:
91 raise Exception("No 'main' data store configured")
109 raise Exception("No 'main' database configured")
92110
93111 if not state:
94 raise Exception("No 'main' data store configured")
112 raise Exception("No 'state' database configured")
95113
96114 # We use local variables here to ensure that the databases do not have
97115 # optional types.
1717 import calendar
1818 import logging
1919 import time
20 from typing import Any, Dict, List, Optional, Tuple
2021
2122 from synapse.api.constants import PresenceState
2223 from synapse.config.homeserver import HomeServerConfig
2728 MultiWriterIdGenerator,
2829 StreamIdGenerator,
2930 )
31 from synapse.types import get_domain_from_id
3032 from synapse.util.caches.stream_change_cache import StreamChangeCache
3133
3234 from .account_data import AccountDataStore
262264 # Used in _generate_user_daily_visits to keep track of progress
263265 self._last_user_visit_update = self._get_start_of_day()
264266
267 def get_device_stream_token(self) -> int:
268 return self._device_list_id_gen.get_current_token()
269
265270 def take_presence_startup_info(self):
266271 active_on_startup = self._presence_on_startup
267272 self._presence_on_startup = None
289294
290295 return [UserPresenceState(**row) for row in rows]
291296
292 def count_daily_users(self):
297 async def count_daily_users(self) -> int:
293298 """
294299 Counts the number of users who used this homeserver in the last 24 hours.
295300 """
296301 yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
297 return self.db_pool.runInteraction(
302 return await self.db_pool.runInteraction(
298303 "count_daily_users", self._count_users, yesterday
299304 )
300305
301 def count_monthly_users(self):
306 async def count_monthly_users(self) -> int:
302307 """
303308 Counts the number of users who used this homeserver in the last 30 days.
304309 Note this method is intended for phonehome metrics only and is different
306311 amongst other things, includes a 3 day grace period before a user counts.
307312 """
308313 thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
309 return self.db_pool.runInteraction(
314 return await self.db_pool.runInteraction(
310315 "count_monthly_users", self._count_users, thirty_days_ago
311316 )
312317
325330 (count,) = txn.fetchone()
326331 return count
327332
328 def count_r30_users(self):
333 async def count_r30_users(self) -> Dict[str, int]:
329334 """
330335 Counts the number of 30 day retained users, defined as:-
331336 * Users who have created their accounts more than 30 days ago
332337 * Where last seen at most 30 days ago
333338 * Where account creation and last_seen are > 30 days apart
334339
335 Returns counts globaly for a given user as well as breaking
336 by platform
340 Returns:
341 A mapping of counts globally as well as broken out by platform.
337342 """
338343
339344 def _count_r30_users(txn):
406411
407412 return results
408413
409 return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
414 return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
410415
411416 def _get_start_of_day(self):
412417 """
416421 today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
417422 return today_start * 1000
418423
419 def generate_user_daily_visits(self):
424 async def generate_user_daily_visits(self) -> None:
420425 """
421426 Generates daily visit data for use in cohort/ retention analysis
422427 """
471476 # frequently
472477 self._last_user_visit_update = now
473478
474 return self.db_pool.runInteraction(
479 await self.db_pool.runInteraction(
475480 "generate_user_daily_visits", _generate_user_daily_visits
476481 )
477482
478 def get_users(self):
483 async def get_users(self) -> List[Dict[str, Any]]:
479484 """Function to retrieve a list of users in users table.
480485
481 Args:
482486 Returns:
483 defer.Deferred: resolves to list[dict[str, Any]]
484 """
485 return self.db_pool.simple_select_list(
487 A list of dictionaries representing users.
488 """
489 return await self.db_pool.simple_select_list(
486490 table="users",
487491 keyvalues={},
488492 retcols=[
496500 desc="get_users",
497501 )
498502
499 def get_users_paginate(
500 self, start, limit, name=None, guests=True, deactivated=False
501 ):
503 async def get_users_paginate(
504 self,
505 start: int,
506 limit: int,
507 user_id: Optional[str] = None,
508 name: Optional[str] = None,
509 guests: bool = True,
510 deactivated: bool = False,
511 ) -> Tuple[List[Dict[str, Any]], int]:
502512 """Function to retrieve a paginated list of users from
503513 users list. This will return a json list of users and the
504514 total number of users matching the filter criteria.
505515
506516 Args:
507 start (int): start number to begin the query from
508 limit (int): number of rows to retrieve
509 name (string): filter for user names
510 guests (bool): whether to in include guest users
511 deactivated (bool): whether to include deactivated users
517 start: start number to begin the query from
518 limit: number of rows to retrieve
519 user_id: search for user_id. Ignored if name is not None
520 name: search for local part of user_id or display name
521 guests: whether to include guest users
522 deactivated: whether to include deactivated users
512523 Returns:
513 defer.Deferred: resolves to list[dict[str, Any]], int
524 A tuple of a list of mappings from user to information and a count of total users.
514525 """
515526
516527 def get_users_paginate_txn(txn):
517528 filters = []
518 args = []
529 args = [self.hs.config.server_name]
519530
520531 if name:
532 filters.append("(name LIKE ? OR displayname LIKE ?)")
533 args.extend(["@%" + name + "%:%", "%" + name + "%"])
534 elif user_id:
521535 filters.append("name LIKE ?")
522 args.append("%" + name + "%")
536 args.extend(["%" + user_id + "%"])
523537
524538 if not guests:
525539 filters.append("is_guest = 0")
529543
530544 where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
531545
532 sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
533 txn.execute(sql, args)
534 count = txn.fetchone()[0]
535
536 args = [self.hs.config.server_name] + args + [limit, start]
537 sql = """
538 SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
546 sql_base = """
539547 FROM users as u
540548 LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
541549 {}
542 ORDER BY u.name LIMIT ? OFFSET ?
543550 """.format(
544551 where_clause
545552 )
553 sql = "SELECT COUNT(*) as total_users " + sql_base
554 txn.execute(sql, args)
555 count = txn.fetchone()[0]
556
557 sql = (
558 "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
559 + sql_base
560 + " ORDER BY u.name LIMIT ? OFFSET ?"
561 )
562 args += [limit, start]
546563 txn.execute(sql, args)
547564 users = self.db_pool.cursor_to_dict(txn)
548565 return users, count
549566
550 return self.db_pool.runInteraction(
567 return await self.db_pool.runInteraction(
551568 "get_users_paginate_txn", get_users_paginate_txn
552569 )
553570
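A usage sketch of the updated pagination API above (assuming a store exposing this mixin; per the new filters, `name` matches against both the local part of the user ID and the display name):

    users, total = await store.get_users_paginate(
        start=0, limit=10, name="ali", guests=False, deactivated=False
    )
    # `total` counts all rows matching the same WHERE clause that
    # produced this page, since both queries share sql_base.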
554 def search_users(self, term):
571 async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
555572 """Function to search users list for one or more users with
556573 the matched term.
557574
558575 Args:
559 term (str): search term
560 col (str): column to query term should be matched to
576 term: search term
577
561578 Returns:
562 defer.Deferred: resolves to list[dict[str, Any]]
563 """
564 return self.db_pool.simple_search_list(
579 A list of dictionaries, or None if the term is empty.
580 """
581 return await self.db_pool.simple_search_list(
565582 table="users",
566583 term=term,
567584 col="name",
574591 """Called before upgrading an existing database to check that it is broadly sane
575592 compared with the configuration.
576593 """
577 domain = config.server_name
578
579 sql = database_engine.convert_param_style(
580 "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
581 )
582 pat = "%:" + domain
583 cur.execute(sql, (pat,))
584 num_not_matching = cur.fetchall()[0][0]
585 if num_not_matching == 0:
594 logger.info("Checking database for consistency with configuration...")
595
596 # if there are any users in the database, check that the username matches our
597 # configured server name.
598
599 cur.execute("SELECT name FROM users LIMIT 1")
600 rows = cur.fetchall()
601 if not rows:
602 return
603
604 user_domain = get_domain_from_id(rows[0][0])
605 if user_domain == config.server_name:
586606 return
587607
588608 raise Exception(
589609 "Found users in database not native to %s!\n"
590 "You cannot changed a synapse server_name after it's been configured"
591 % (domain,)
610 "You cannot change a synapse server_name after it's been configured"
611 % (config.server_name,)
592612 )
593613
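The rewritten check above only needs one sample row because every local user shares the server's domain; `get_domain_from_id` splits a Matrix ID at its first colon. A sketch of the comparison being made:

    # get_domain_from_id("@alice:example.com") == "example.com"
    user_domain = get_domain_from_id("@alice:example.com")
    if user_domain != config.server_name:
        pass  # mismatch: startup is aborted with the exception above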
594614
1515
1616 import abc
1717 import logging
18 from typing import List, Optional, Tuple
19
20 from twisted.internet import defer
18 from typing import Dict, List, Optional, Tuple
2119
2220 from synapse.storage._base import SQLBaseStore, db_to_json
2321 from synapse.storage.database import DatabasePool
5755 raise NotImplementedError()
5856
5957 @cached()
60 def get_account_data_for_user(self, user_id):
58 async def get_account_data_for_user(
59 self, user_id: str
60 ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
6161 """Get all the client account_data for a user.
6262
6363 Args:
64 user_id(str): The user to get the account_data for.
65 Returns:
66 A deferred pair of a dict of global account_data and a dict
67 mapping from room_id string to per room account_data dicts.
64 user_id: The user to get the account_data for.
65 Returns:
66 A 2-tuple of a dict of global account_data and a dict mapping from
67 room_id string to per room account_data dicts.
6868 """
6969
7070 def get_account_data_for_user_txn(txn):
9393
9494 return global_account_data, by_room
9595
96 return self.db_pool.runInteraction(
96 return await self.db_pool.runInteraction(
9797 "get_account_data_for_user", get_account_data_for_user_txn
9898 )
9999
119119 return None
120120
121121 @cached(num_args=2)
122 def get_account_data_for_room(self, user_id, room_id):
122 async def get_account_data_for_room(
123 self, user_id: str, room_id: str
124 ) -> Dict[str, JsonDict]:
123125 """Get all the client account_data for a user for a room.
124126
125127 Args:
126 user_id(str): The user to get the account_data for.
127 room_id(str): The room to get the account_data for.
128 Returns:
129 A deferred dict of the room account_data
128 user_id: The user to get the account_data for.
129 room_id: The room to get the account_data for.
130 Returns:
131 A dict of the room account_data
130132 """
131133
132134 def get_account_data_for_room_txn(txn):
141143 row["account_data_type"]: db_to_json(row["content"]) for row in rows
142144 }
143145
144 return self.db_pool.runInteraction(
146 return await self.db_pool.runInteraction(
145147 "get_account_data_for_room", get_account_data_for_room_txn
146148 )
147149
148150 @cached(num_args=3, max_entries=5000)
149 def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
151 async def get_account_data_for_room_and_type(
152 self, user_id: str, room_id: str, account_data_type: str
153 ) -> Optional[JsonDict]:
150154 """Get the client account_data of given type for a user for a room.
151155
152156 Args:
153 user_id(str): The user to get the account_data for.
154 room_id(str): The room to get the account_data for.
155 account_data_type (str): The account data type to get.
156 Returns:
157 A deferred of the room account_data for that type, or None if
158 there isn't any set.
157 user_id: The user to get the account_data for.
158 room_id: The room to get the account_data for.
159 account_data_type: The account data type to get.
160 Returns:
161 The room account_data for that type, or None if there isn't any set.
159162 """
160163
161164 def get_account_data_for_room_and_type_txn(txn):
173176
174177 return db_to_json(content_json) if content_json else None
175178
176 return self.db_pool.runInteraction(
179 return await self.db_pool.runInteraction(
177180 "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
178181 )
179182
237240 "get_updated_room_account_data", get_updated_room_account_data_txn
238241 )
239242
240 def get_updated_account_data_for_user(self, user_id, stream_id):
243 async def get_updated_account_data_for_user(
244 self, user_id: str, stream_id: int
245 ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
241246 """Get all the client account_data for a that's changed for a user
242247
243248 Args:
244 user_id(str): The user to get the account_data for.
245 stream_id(int): The point in the stream since which to get updates
249 user_id: The user to get the account_data for.
250 stream_id: The point in the stream since which to get updates
246251 Returns:
247252 A 2-tuple of a dict of global account_data and a dict
248253 mapping from room_id string to per room account_data dicts.
276281 user_id, int(stream_id)
277282 )
278283 if not changed:
279 return defer.succeed(({}, {}))
280
281 return self.db_pool.runInteraction(
284 return ({}, {})
285
286 return await self.db_pool.runInteraction(
282287 "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
283288 )
284289
335340 """
336341 content_json = json_encoder.encode(content)
337342
338 with self._account_data_id_gen.get_next() as next_id:
343 with await self._account_data_id_gen.get_next() as next_id:
339344 # no need to lock here as room_account_data has a unique constraint
340345 # on (user_id, room_id, account_data_type) so simple_upsert will
341346 # retry if there is a conflict.
383388 """
384389 content_json = json_encoder.encode(content)
385390
386 with self._account_data_id_gen.get_next() as next_id:
391 with await self._account_data_id_gen.get_next() as next_id:
387392 # no need to lock here as account_data has a unique constraint on
388393 # (user_id, account_data_type) so simple_upsert will retry if
389394 # there is a conflict.
415420
416421 return self._account_data_id_gen.get_current_token()
417422
418 def _update_max_stream_id(self, next_id: int):
423 async def _update_max_stream_id(self, next_id: int) -> None:
419424 """Update the max stream_id
420425
421426 Args:
434439 )
435440 txn.execute(update_max_id_sql, (next_id, next_id))
436441
437 return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
442 await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
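The repeated `with await ..._id_gen.get_next()` changes in this hunk reflect the stream ID generators becoming asynchronous: `get_next()` now has to be awaited to obtain the context manager that reserves the next stream ID. A minimal sketch of the pattern (the transaction callback name is illustrative):

    # Await get_next() first, then enter the returned context manager; the
    # reserved stream ID is only marked as persisted once the block exits
    # without error.
    with await self._account_data_id_gen.get_next() as next_id:
        await self.db_pool.runInteraction(
            "add_account_data", _add_account_data_txn, next_id
        )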
1515 import logging
1616 import re
1717
18 from canonicaljson import json
19
2018 from synapse.appservice import AppServiceTransaction
2119 from synapse.config.appservice import load_appservices
2220 from synapse.storage._base import SQLBaseStore, db_to_json
2321 from synapse.storage.database import DatabasePool
2422 from synapse.storage.databases.main.events_worker import EventsWorkerStore
23 from synapse.util import json_encoder
2524
2625 logger = logging.getLogger(__name__)
2726
161160 return result.get("state")
162161 return None
163162
164 def set_appservice_state(self, service, state):
163 async def set_appservice_state(self, service, state) -> None:
165164 """Set the application service state.
166165
167166 Args:
168167 service(ApplicationService): The service whose state to set.
169168 state(ApplicationServiceState): The connectivity state to apply.
170 Returns:
171 A Deferred which resolves when the state was set successfully.
172 """
173 return self.db_pool.simple_upsert(
169 """
170 await self.db_pool.simple_upsert(
174171 "application_services_state", {"as_id": service.id}, {"state": state}
175172 )
176173
177 def create_appservice_txn(self, service, events):
174 async def create_appservice_txn(self, service, events):
178175 """Atomically creates a new transaction for this application service
179176 with the given list of events.
180177
203200 new_txn_id = max(highest_txn_id, last_txn_id) + 1
204201
205202 # Insert new txn into txn table
206 event_ids = json.dumps([e.event_id for e in events])
203 event_ids = json_encoder.encode([e.event_id for e in events])
207204 txn.execute(
208205 "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
209206 "VALUES(?,?,?)",
211208 )
212209 return AppServiceTransaction(service=service, id=new_txn_id, events=events)
213210
214 return self.db_pool.runInteraction(
211 return await self.db_pool.runInteraction(
215212 "create_appservice_txn", _create_appservice_txn
216213 )
217214
218 def complete_appservice_txn(self, txn_id, service):
215 async def complete_appservice_txn(self, txn_id, service) -> None:
219216 """Completes an application service transaction.
220217
221218 Args:
222219 txn_id(str): The transaction ID being completed.
223220 service(ApplicationService): The application service which was sent
224221 this transaction.
225 Returns:
226 A Deferred which resolves if this transaction was stored
227 successfully.
228222 """
229223 txn_id = int(txn_id)
230224
260254 {"txn_id": txn_id, "as_id": service.id},
261255 )
262256
263 return self.db_pool.runInteraction(
257 await self.db_pool.runInteraction(
264258 "complete_appservice_txn", _complete_appservice_txn
265259 )
266260
314308 else:
315309 return int(last_txn_id[0]) # select 'last_txn' col
316310
317 def set_appservice_last_pos(self, pos):
311 async def set_appservice_last_pos(self, pos) -> None:
318312 def set_appservice_last_pos_txn(txn):
319313 txn.execute(
320314 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
321315 )
322316
323 return self.db_pool.runInteraction(
317 await self.db_pool.runInteraction(
324318 "set_appservice_last_pos", set_appservice_last_pos_txn
325319 )
326320
298298 },
299299 )
300300
301 def get_cache_stream_token(self, instance_name):
301 def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
302302 if self._cache_id_gen:
303 return self._cache_id_gen.get_current_token(instance_name)
303 return self._cache_id_gen.get_current_token_for_writer(instance_name)
304304 else:
305305 return 0
395395 self._batch_row_update[key] = (user_agent, device_id, now)
396396
397397 @wrap_as_background_process("update_client_ips")
398 def _update_client_ips_batch(self):
398 async def _update_client_ips_batch(self) -> None:
399399
400400 # If the DB pool has already terminated, don't try updating
401401 if not self.db_pool.is_running():
404404 to_update = self._batch_row_update
405405 self._batch_row_update = {}
406406
407 return self.db_pool.runInteraction(
407 await self.db_pool.runInteraction(
408408 "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
409409 )
410410
189189 )
190190
191191 @trace
192 def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
192 async def delete_device_msgs_for_remote(
193 self, destination: str, up_to_stream_id: int
194 ) -> None:
193195 """Used to delete messages when the remote destination acknowledges
194196 their receipt.
195197
196198 Args:
197 destination(str): The destination server_name
198 up_to_stream_id(int): Where to delete messages up to.
199 Returns:
200 A deferred that resolves when the messages have been deleted.
199 destination: The destination server_name
200 up_to_stream_id: Where to delete messages up to.
201201 """
202202
203203 def delete_messages_for_remote_destination_txn(txn):
208208 )
209209 txn.execute(sql, (destination, up_to_stream_id))
210210
211 return self.db_pool.runInteraction(
211 await self.db_pool.runInteraction(
212212 "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
213213 )
214214
361361 rows.append((destination, stream_id, now_ms, edu_json))
362362 txn.executemany(sql, rows)
363363
364 with self._device_inbox_id_gen.get_next() as stream_id:
364 with await self._device_inbox_id_gen.get_next() as stream_id:
365365 now_ms = self.clock.time_msec()
366366 await self.db_pool.runInteraction(
367367 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
410410 txn, stream_id, local_messages_by_user_then_device
411411 )
412412
413 with self._device_inbox_id_gen.get_next() as stream_id:
413 with await self._device_inbox_id_gen.get_next() as stream_id:
414414 now_ms = self.clock.time_msec()
415415 await self.db_pool.runInteraction(
416416 "add_messages_from_remote_to_device_inbox",
1313 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
16 import abc
1617 import logging
17 from typing import Dict, Iterable, List, Optional, Set, Tuple
18 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
1819
1920 from synapse.api.errors import Codes, StoreError
2021 from synapse.logging.opentracing import (
4647
4748
4849 class DeviceWorkerStore(SQLBaseStore):
49 def get_device(self, user_id: str, device_id: str):
50 async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
5051 """Retrieve a device. Only returns devices that are not marked as
5152 hidden.
5253
5455 user_id: The ID of the user which owns the device
5556 device_id: The ID of the device to retrieve
5657 Returns:
57 defer.Deferred for a dict containing the device information
58 A dict containing the device information
5859 Raises:
5960 StoreError: if the device is not found
6061 """
61 return self.db_pool.simple_select_one(
62 return await self.db_pool.simple_select_one(
6263 table="devices",
6364 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
6465 retcols=("user_id", "device_id", "display_name"),
100101 update included in the response), and the list of updates, where
101102 each update is a pair of EDU type and EDU contents.
102103 """
103 now_stream_id = self._device_list_id_gen.get_current_token()
104 now_stream_id = self.get_device_stream_token()
104105
105106 has_changed = self._device_list_federation_stream_cache.has_entity_changed(
106107 destination, int(from_stream_id)
253254 List of objects representing an device update EDU
254255 """
255256 devices = (
256 await self.db_pool.runInteraction(
257 "_get_e2e_device_keys_txn",
258 self._get_e2e_device_keys_txn,
257 await self.get_e2e_device_keys_and_signatures(
259258 query_map.keys(),
260259 include_all_devices=True,
261260 include_deleted_devices=True,
291290 prev_id = stream_id
292291
293292 if device is not None:
294 key_json = device.get("key_json", None)
295 if key_json:
296 result["keys"] = db_to_json(key_json)
297
298 if "signatures" in device:
299 for sig_user_id, sigs in device["signatures"].items():
300 result["keys"].setdefault("signatures", {}).setdefault(
301 sig_user_id, {}
302 ).update(sigs)
303
304 device_display_name = device.get("device_display_name", None)
293 keys = device.keys
294 if keys:
295 result["keys"] = keys
296
297 device_display_name = device.display_name
305298 if device_display_name:
306299 result["device_display_name"] = device_display_name
307300 else:
311304
312305 return results
313306
314 def _get_last_device_update_for_remote_user(
307 async def _get_last_device_update_for_remote_user(
315308 self, destination: str, user_id: str, from_stream_id: int
316 ):
309 ) -> int:
317310 def f(txn):
318311 prev_sent_id_sql = """
319312 SELECT coalesce(max(stream_id), 0) as stream_id
324317 rows = txn.fetchall()
325318 return rows[0][0]
326319
327 return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
328
329 def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
320 return await self.db_pool.runInteraction(
321 "get_last_device_update_for_remote_user", f
322 )
323
324 async def mark_as_sent_devices_by_remote(
325 self, destination: str, stream_id: int
326 ) -> None:
330327 """Mark that updates have successfully been sent to the destination.
331328 """
332 return self.db_pool.runInteraction(
329 await self.db_pool.runInteraction(
333330 "mark_as_sent_devices_by_remote",
334331 self._mark_as_sent_devices_by_remote_txn,
335332 destination,
379376 The new stream ID.
380377 """
381378
382 with self._device_list_id_gen.get_next() as stream_id:
379 with await self._device_list_id_gen.get_next() as stream_id:
383380 await self.db_pool.runInteraction(
384381 "add_user_sig_change_to_streams",
385382 self._add_user_signature_change_txn,
411408 },
412409 )
413410
411 @abc.abstractmethod
414412 def get_device_stream_token(self) -> int:
415 return self._device_list_id_gen.get_current_token()
413 """Get the current stream id from the _device_list_id_gen"""
414 ...
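
Turning `get_device_stream_token` into an abstract method means `DeviceWorkerStore` no longer assumes it owns `_device_list_id_gen`; whichever concrete store does own the generator supplies the token. The same abstract method is added to `EndToEndKeyWorkerStore` later in this diff. A toy sketch of the pattern, names hypothetical:

    import abc

    class WorkerStore(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        def get_device_stream_token(self) -> int:
            """Provided by whichever subclass owns the ID generator."""
            ...

    class MainStore(WorkerStore):
        def __init__(self) -> None:
            # stands in for self._device_list_id_gen.get_current_token()
            self._current_token = 42

        def get_device_stream_token(self) -> int:
            return self._current_token

    print(MainStore().get_device_stream_token())  # 42
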
416415
417416 @trace
418417 async def get_user_devices_from_cache(
479478 return {
480479 device["device_id"]: db_to_json(device["content"]) for device in devices
481480 }
482
483 def get_devices_with_keys_by_user(self, user_id: str):
484 """Get all devices (with any device keys) for a user
485
486 Returns:
487 Deferred which resolves to (stream_id, devices)
488 """
489 return self.db_pool.runInteraction(
490 "get_devices_with_keys_by_user",
491 self._get_devices_with_keys_by_user_txn,
492 user_id,
493 )
494
495 def _get_devices_with_keys_by_user_txn(
496 self, txn: LoggingTransaction, user_id: str
497 ) -> Tuple[int, List[JsonDict]]:
498 now_stream_id = self._device_list_id_gen.get_current_token()
499
500 devices = self._get_e2e_device_keys_txn(
501 txn, [(user_id, None)], include_all_devices=True
502 )
503
504 if devices:
505 user_devices = devices[user_id]
506 results = []
507 for device_id, device in user_devices.items():
508 result = {"device_id": device_id}
509
510 key_json = device.get("key_json", None)
511 if key_json:
512 result["keys"] = db_to_json(key_json)
513
514 if "signatures" in device:
515 for sig_user_id, sigs in device["signatures"].items():
516 result["keys"].setdefault("signatures", {}).setdefault(
517 sig_user_id, {}
518 ).update(sigs)
519
520 device_display_name = device.get("device_display_name", None)
521 if device_display_name:
522 result["device_display_name"] = device_display_name
523
524 results.append(result)
525
526 return now_stream_id, results
527
528 return now_stream_id, []
529481
530482 async def get_users_whose_devices_changed(
531483 self, from_key: str, user_ids: Iterable[str]
655607 )
656608
657609 @cached(max_entries=10000)
658 def get_device_list_last_stream_id_for_remote(self, user_id: str):
610 async def get_device_list_last_stream_id_for_remote(
611 self, user_id: str
612 ) -> Optional[Any]:
659613 """Get the last stream_id we got for a user. May be None if we haven't
660614 got any information for them.
661615 """
662 return self.db_pool.simple_select_one_onecol(
616 return await self.db_pool.simple_select_one_onecol(
663617 table="device_lists_remote_extremeties",
664618 keyvalues={"user_id": user_id},
665619 retcol="stream_id",
670624 @cachedList(
671625 cached_method_name="get_device_list_last_stream_id_for_remote",
672626 list_name="user_ids",
673 inlineCallbacks=True,
674627 )
675 def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
676 rows = yield self.db_pool.simple_select_many_batch(
628 async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
629 rows = await self.db_pool.simple_select_many_batch(
677630 table="device_lists_remote_extremeties",
678631 column="user_id",
679632 iterable=user_ids,
714667
715668 return {row["user_id"] for row in rows}
716669
717 def mark_remote_user_device_cache_as_stale(self, user_id: str):
670 async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
718671 """Records that the server has reason to believe the cache of the devices
719672 for the remote users is out of date.
720673 """
721 return self.db_pool.simple_upsert(
674 await self.db_pool.simple_upsert(
722675 table="device_lists_remote_resync",
723676 keyvalues={"user_id": user_id},
724677 values={},
726679 desc="make_remote_user_device_cache_as_stale",
727680 )
728681
729 def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
682 async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
730683 """Mark that we no longer track device lists for remote user.
731684 """
732685
740693 txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
741694 )
742695
743 return self.db_pool.runInteraction(
696 await self.db_pool.runInteraction(
744697 "mark_remote_user_device_list_as_unsubscribed",
745698 _mark_remote_user_device_list_as_unsubscribed_txn,
746699 )
1001954 desc="update_device",
1002955 )
1003956
1004 def update_remote_device_list_cache_entry(
957 async def update_remote_device_list_cache_entry(
1005958 self, user_id: str, device_id: str, content: JsonDict, stream_id: int
1006 ):
959 ) -> None:
1007960 """Updates a single device in the cache of a remote user's devicelist.
1008961
1009962 Note: assumes that we are the only thread that can be updating this user's
1014967 device_id: ID of device being updated
1015968 content: new data on this device
1016969 stream_id: the version of the device list
1017
1018 Returns:
1019 Deferred[None]
1020 """
1021 return self.db_pool.runInteraction(
970 """
971 await self.db_pool.runInteraction(
1022972 "update_remote_device_list_cache_entry",
1023973 self._update_remote_device_list_cache_entry_txn,
1024974 user_id,
10701020 lock=False,
10711021 )
10721022
1073 def update_remote_device_list_cache(
1023 async def update_remote_device_list_cache(
10741024 self, user_id: str, devices: List[dict], stream_id: int
1075 ):
1025 ) -> None:
10761026 """Replace the entire cache of the remote user's devices.
10771027
10781028 Note: assumes that we are the only thread that can be updating this user's
10821032 user_id: User to update device list for
10831033 devices: list of device objects supplied over federation
10841034 stream_id: the version of the device list
1085
1086 Returns:
1087 Deferred[None]
1088 """
1089 return self.db_pool.runInteraction(
1035 """
1036 await self.db_pool.runInteraction(
10901037 "update_remote_device_list_cache",
10911038 self._update_remote_device_list_cache_txn,
10921039 user_id,
10961043
10971044 def _update_remote_device_list_cache_txn(
10981045 self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
1099 ):
1046 ) -> None:
11001047 self.db_pool.simple_delete_txn(
11011048 txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
11021049 )
11461093 if not device_ids:
11471094 return
11481095
1149 with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
1096 with await self._device_list_id_gen.get_next_mult(
1097 len(device_ids)
1098 ) as stream_ids:
11501099 await self.db_pool.runInteraction(
11511100 "add_device_change_to_stream",
11521101 self._add_device_change_to_stream_txn,
11591108 return stream_ids[-1]
11601109
11611110 context = get_active_span_text_map()
1162 with self._device_list_id_gen.get_next_mult(
1111 with await self._device_list_id_gen.get_next_mult(
11631112 len(hosts) * len(device_ids)
11641113 ) as stream_ids:
11651114 await self.db_pool.runInteraction(
1313 # limitations under the License.
1414
1515 from collections import namedtuple
16 from typing import Iterable, Optional
16 from typing import Iterable, List, Optional
1717
1818 from synapse.api.errors import SynapseError
1919 from synapse.storage._base import SQLBaseStore
5858
5959 return RoomAliasMapping(room_id, room_alias.to_string(), servers)
6060
61 def get_room_alias_creator(self, room_alias):
62 return self.db_pool.simple_select_one_onecol(
61 async def get_room_alias_creator(self, room_alias: str) -> str:
62 return await self.db_pool.simple_select_one_onecol(
6363 table="room_aliases",
6464 keyvalues={"room_alias": room_alias},
6565 retcol="creator",
6767 )
6868
6969 @cached(max_entries=5000)
70 def get_aliases_for_room(self, room_id):
71 return self.db_pool.simple_select_onecol(
70 async def get_aliases_for_room(self, room_id: str) -> List[str]:
71 return await self.db_pool.simple_select_onecol(
7272 "room_aliases",
7373 {"room_id": room_id},
7474 "room_alias",
158158
159159 return room_id
160160
161 def update_aliases_for_room(
161 async def update_aliases_for_room(
162162 self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
163 ):
163 ) -> None:
164164 """Repoint all of the aliases for a given room, to a different room.
165165
166166 Args:
188188 txn, self.get_aliases_for_room, (new_room_id,)
189189 )
190190
191 return self.db_pool.runInteraction(
191 await self.db_pool.runInteraction(
192192 "_update_aliases_for_room_txn", _update_aliases_for_room_txn
193193 )
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 from typing import Optional
17
1618 from synapse.api.errors import StoreError
1719 from synapse.logging.opentracing import log_kv, trace
1820 from synapse.storage._base import SQLBaseStore, db_to_json
148150
149151 return sessions
150152
151 def get_e2e_room_keys_multi(self, user_id, version, room_keys):
153 async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
152154 """Get multiple room keys at a time. The difference between this function and
153155 get_e2e_room_keys is that this function can be used to retrieve
154156 multiple specific keys at a time, whereas get_e2e_room_keys is used for
163165 that we want to query
164166
165167 Returns:
166 Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
167 """
168
169 return self.db_pool.runInteraction(
168 dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
169 """
170
171 return await self.db_pool.runInteraction(
170172 "get_e2e_room_keys_multi",
171173 self._get_e2e_room_keys_multi_txn,
172174 user_id,
222224
223225 return ret
224226
225 def count_e2e_room_keys(self, user_id, version):
227 async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
226228 """Get the number of keys in a backup version.
227229
228230 Args:
229 user_id (str): the user whose backup we're querying
230 version (str): the version ID of the backup we're querying about
231 """
232
233 return self.db_pool.simple_select_one_onecol(
231 user_id: the user whose backup we're querying
232 version: the version ID of the backup we're querying about
233 """
234
235 return await self.db_pool.simple_select_one_onecol(
234236 table="e2e_room_keys",
235237 keyvalues={"user_id": user_id, "version": version},
236238 retcol="COUNT(*)",
280282 raise StoreError(404, "No current backup version")
281283 return row[0]
282284
283 def get_e2e_room_keys_version_info(self, user_id, version=None):
285 async def get_e2e_room_keys_version_info(self, user_id, version=None):
284286 """Get info metadata about a version of our room_keys backup.
285287
286288 Args:
290292 Raises:
291293 StoreError: with code 404 if there are no e2e_room_keys_versions present
292294 Returns:
293 A deferred dict giving the info metadata for this backup version, with
295 A dict giving the info metadata for this backup version, with
294296 fields including:
295297 version(str)
296298 algorithm(str)
321323 result["etag"] = 0
322324 return result
323325
324 return self.db_pool.runInteraction(
326 return await self.db_pool.runInteraction(
325327 "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
326328 )
327329
328330 @trace
329 def create_e2e_room_keys_version(self, user_id, info):
331 async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
330332 """Atomically creates a new version of this user's e2e_room_keys store
331333 with the given version info.
332334
335337 info(dict): the info about the backup version to be created
336338
337339 Returns:
338 A deferred string for the newly created version ID
340 The newly created version ID
339341 """
340342
341343 def _create_e2e_room_keys_version_txn(txn):
362364
363365 return new_version
364366
365 return self.db_pool.runInteraction(
367 return await self.db_pool.runInteraction(
366368 "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
367369 )
368370
369371 @trace
370 def update_e2e_room_keys_version(
371 self, user_id, version, info=None, version_etag=None
372 ):
372 async def update_e2e_room_keys_version(
373 self,
374 user_id: str,
375 version: str,
376 info: Optional[dict] = None,
377 version_etag: Optional[int] = None,
378 ) -> None:
373379 """Update a given backup version
374380
375381 Args:
376 user_id(str): the user whose backup version we're updating
377 version(str): the version ID of the backup version we're updating
378 info (dict): the new backup version info to store. If None, then
379 the backup version info is not updated
380 version_etag (Optional[int]): etag of the keys in the backup. If
381 None, then the etag is not updated
382 user_id: the user whose backup version we're updating
383 version: the version ID of the backup version we're updating
384 info: the new backup version info to store. If None, then the backup
385 version info is not updated.
386 version_etag: etag of the keys in the backup. If None, then the etag
387 is not updated.
382388 """
383389 updatevalues = {}
384390
388394 updatevalues["etag"] = version_etag
389395
390396 if updatevalues:
391 return self.db_pool.simple_update(
397 await self.db_pool.simple_update(
392398 table="e2e_room_keys_versions",
393399 keyvalues={"user_id": user_id, "version": version},
394400 updatevalues=updatevalues,
396402 )
397403
398404 @trace
399 def delete_e2e_room_keys_version(self, user_id, version=None):
405 async def delete_e2e_room_keys_version(
406 self, user_id: str, version: Optional[str] = None
407 ) -> None:
400408 """Delete a given backup version of the user's room keys.
401409 Doesn't delete their actual key data.
402410
403411 Args:
404 user_id(str): the user whose backup version we're deleting
405 version(str): Optional. the version ID of the backup version we're deleting
412 user_id: the user whose backup version we're deleting
413 version: Optional. the version ID of the backup version we're deleting
406414 If missing, we delete the current backup version info.
407415 Raises:
408416 StoreError: with code 404 if there are no e2e_room_keys_versions present,
423431 keyvalues={"user_id": user_id, "version": this_version},
424432 )
425433
426 return self.db_pool.simple_update_one_txn(
434 self.db_pool.simple_update_one_txn(
427435 txn,
428436 table="e2e_room_keys_versions",
429437 keyvalues={"user_id": user_id, "version": this_version},
430438 updatevalues={"deleted": 1},
431439 )
432440
433 return self.db_pool.runInteraction(
441 await self.db_pool.runInteraction(
434442 "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
435443 )
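
`update_e2e_room_keys_version` above now awaits the update directly, and keeps its existing guard: it writes only the fields it was actually given and skips the UPDATE entirely when `updatevalues` stays empty. A small sketch of that conditional-update shape (the `auth_data` branch is an assumption by analogy; only the `etag` branch is visible in the hunk):

    import json
    from typing import Optional

    def build_updatevalues(info: Optional[dict], version_etag: Optional[int]) -> dict:
        updatevalues = {}
        if info is not None and "auth_data" in info:
            # assumed field, mirroring the etag branch shown above
            updatevalues["auth_data"] = json.dumps(info["auth_data"])
        if version_etag is not None:
            updatevalues["etag"] = version_etag
        return updatevalues

    print(build_updatevalues(None, 7))     # {'etag': 7}
    print(build_updatevalues(None, None))  # {} -> the UPDATE is skipped
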
1313 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
16 from typing import Dict, Iterable, List, Optional, Tuple
17
16 import abc
17 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
18
19 import attr
1820 from canonicaljson import encode_canonical_json
1921
2022 from twisted.enterprise.adbapi import Connection
2224 from synapse.logging.opentracing import log_kv, set_tag, trace
2325 from synapse.storage._base import SQLBaseStore, db_to_json
2426 from synapse.storage.database import make_in_list_sql_clause
27 from synapse.storage.types import Cursor
28 from synapse.types import JsonDict
2529 from synapse.util import json_encoder
2630 from synapse.util.caches.descriptors import cached, cachedList
2731 from synapse.util.iterutils import batch_iter
2832
33 if TYPE_CHECKING:
34 from synapse.handlers.e2e_keys import SignatureListItem
35
36
37 @attr.s
38 class DeviceKeyLookupResult:
39 """The type returned by get_e2e_device_keys_and_signatures"""
40
41 display_name = attr.ib(type=Optional[str])
42
43 # the key data from e2e_device_keys_json. Typically includes fields like
44 # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
45 # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
46 keys = attr.ib(type=Optional[JsonDict])
47
2948
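
Elsewhere in this diff, callers switch from dict row access (`device.get("key_json")`, `device.get("device_display_name")`) to attribute access on this class. A tiny self-contained usage sketch, with `JsonDict` simplified to a plain dict:

    import attr
    from typing import Optional

    @attr.s
    class DeviceKeyLookupResult:
        display_name = attr.ib(type=Optional[str])
        keys = attr.ib(type=Optional[dict])  # None for a device without keys

    device = DeviceKeyLookupResult("alice's phone", {"algorithms": ["m.olm.v1"]})
    if device.keys:
        print(device.display_name, "->", device.keys["algorithms"])
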
3049 class EndToEndKeyWorkerStore(SQLBaseStore):
50 async def get_e2e_device_keys_for_federation_query(
51 self, user_id: str
52 ) -> Tuple[int, List[JsonDict]]:
53 """Get all devices (with any device keys) for a user
54
55 Returns:
56 (stream_id, devices)
57 """
58 now_stream_id = self.get_device_stream_token()
59
60 devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
61
62 if devices:
63 user_devices = devices[user_id]
64 results = []
65 for device_id, device in user_devices.items():
66 result = {"device_id": device_id}
67
68 keys = device.keys
69 if keys:
70 result["keys"] = keys
71
72 device_display_name = device.display_name
73 if device_display_name:
74 result["device_display_name"] = device_display_name
75
76 results.append(result)
77
78 return now_stream_id, results
79
80 return now_stream_id, []
81
3182 @trace
32 async def get_e2e_device_keys(
33 self, query_list, include_all_devices=False, include_deleted_devices=False
34 ):
35 """Fetch a list of device keys.
83 async def get_e2e_device_keys_for_cs_api(
84 self, query_list: List[Tuple[str, Optional[str]]]
85 ) -> Dict[str, Dict[str, JsonDict]]:
86 """Fetch a list of device keys, formatted suitably for the C/S API.
3687 Args:
3788 query_list(list): List of pairs of user_ids and device_ids.
38 include_all_devices (bool): whether to include entries for devices
39 that don't have device keys
40 include_deleted_devices (bool): whether to include null entries for
41 devices which no longer exist (but were in the query_list).
42 This option only takes effect if include_all_devices is true.
4389 Returns:
4490 Dict mapping from user-id to dict mapping from device_id to
4591 key data. The key data will be a dict in the same format as the
4995 if not query_list:
5096 return {}
5197
52 results = await self.db_pool.runInteraction(
53 "get_e2e_device_keys",
54 self._get_e2e_device_keys_txn,
55 query_list,
56 include_all_devices,
57 include_deleted_devices,
58 )
98 results = await self.get_e2e_device_keys_and_signatures(query_list)
5999
60100 # Build the result structure, un-jsonify the results, and add the
61101 # "unsigned" section
63103 for user_id, device_keys in results.items():
64104 rv[user_id] = {}
65105 for device_id, device_info in device_keys.items():
66 r = db_to_json(device_info.pop("key_json"))
106 r = device_info.keys
67107 r["unsigned"] = {}
68 display_name = device_info["device_display_name"]
108 display_name = device_info.display_name
69109 if display_name is not None:
70110 r["unsigned"]["device_display_name"] = display_name
71 if "signatures" in device_info:
72 for sig_user_id, sigs in device_info["signatures"].items():
73 r.setdefault("signatures", {}).setdefault(
74 sig_user_id, {}
75 ).update(sigs)
76111 rv[user_id][device_id] = r
77112
78113 return rv
79114
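
`get_e2e_device_keys_for_cs_api` reshapes each `DeviceKeyLookupResult` into the client-server wire format: the stored key JSON becomes the body and the display name moves under `"unsigned"`. Roughly as follows (a sketch: the method above reuses the key dict in place rather than copying):

    from types import SimpleNamespace

    def to_cs_api(device_info) -> dict:
        r = dict(device_info.keys)  # the real code mutates device_info.keys directly
        r["unsigned"] = {}
        if device_info.display_name is not None:
            r["unsigned"]["device_display_name"] = device_info.display_name
        return r

    info = SimpleNamespace(display_name="laptop", keys={"algorithms": ["m.olm.v1"]})
    print(to_cs_api(info))
    # {'algorithms': ['m.olm.v1'], 'unsigned': {'device_display_name': 'laptop'}}
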
80115 @trace
116 async def get_e2e_device_keys_and_signatures(
117 self,
118 query_list: List[Tuple[str, Optional[str]]],
119 include_all_devices: bool = False,
120 include_deleted_devices: bool = False,
121 ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
122 """Fetch a list of device keys
123
124 Any cross-signatures made on the keys by the owner of the device are also
125 included.
126
127 The cross-signatures are added to the `signatures` field within the `keys`
128 object in the response.
129
130 Args:
131 query_list: List of pairs of user_ids and device_ids. Device id can be None
132 to indicate "all devices for this user"
133
134 include_all_devices: whether to return devices without device keys
135
136 include_deleted_devices: whether to include null entries for
137 devices which no longer exist (but were in the query_list).
138 This option only takes effect if include_all_devices is true.
139
140 Returns:
141 Dict mapping from user-id to dict mapping from device_id to
142 key data.
143 """
144 set_tag("include_all_devices", include_all_devices)
145 set_tag("include_deleted_devices", include_deleted_devices)
146
147 result = await self.db_pool.runInteraction(
148 "get_e2e_device_keys",
149 self._get_e2e_device_keys_txn,
150 query_list,
151 include_all_devices,
152 include_deleted_devices,
153 )
154
155 # get the (user_id, device_id) tuples to look up cross-signatures for
156 signature_query = (
157 (user_id, device_id)
158 for user_id, dev in result.items()
159 for device_id, d in dev.items()
160 if d is not None and d.keys is not None
161 )
162
163 for batch in batch_iter(signature_query, 50):
164 cross_sigs_result = await self.db_pool.runInteraction(
165 "get_e2e_cross_signing_signatures",
166 self._get_e2e_cross_signing_signatures_for_devices_txn,
167 batch,
168 )
169
170 # add each cross-signing signature to the correct device in the result dict.
171 for (user_id, key_id, device_id, signature) in cross_sigs_result:
172 target_device_result = result[user_id][device_id]
173 target_device_signatures = target_device_result.keys.setdefault(
174 "signatures", {}
175 )
176 signing_user_signatures = target_device_signatures.setdefault(
177 user_id, {}
178 )
179 signing_user_signatures[key_id] = signature
180
181 log_kv(result)
182 return result
183
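
`batch_iter`, imported from `synapse.util.iterutils` at the top of this file, chunks the signature query so each `runInteraction` above looks up at most 50 devices. A generic chunker with the same observable behaviour, for illustration:

    from itertools import islice
    from typing import Iterable, Iterator

    def batch_iter(iterable: Iterable, size: int) -> Iterator[tuple]:
        it = iter(iterable)
        # yield tuples of up to `size` items until the source is exhausted
        return iter(lambda: tuple(islice(it, size)), ())

    queries = ((f"@user{i}:hs", f"DEV{i}") for i in range(120))
    print([len(batch) for batch in batch_iter(queries, 50)])  # [50, 50, 20]
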
81184 def _get_e2e_device_keys_txn(
82185 self, txn, query_list, include_all_devices=False, include_deleted_devices=False
83 ):
84 set_tag("include_all_devices", include_all_devices)
85 set_tag("include_deleted_devices", include_deleted_devices)
86
186 ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
187 """Get information on devices from the database
188
189 The results include the device's keys and self-signatures, but *not* any
190 cross-signing signatures which have been added subsequently (for which, see
191 get_e2e_device_keys_and_signatures)
192 """
87193 query_clauses = []
88194 query_params = []
89 signature_query_clauses = []
90 signature_query_params = []
91195
92196 if include_all_devices is False:
93197 include_deleted_devices = False
98202 for (user_id, device_id) in query_list:
99203 query_clause = "user_id = ?"
100204 query_params.append(user_id)
101 signature_query_clause = "target_user_id = ?"
102 signature_query_params.append(user_id)
103205
104206 if device_id is not None:
105207 query_clause += " AND device_id = ?"
106208 query_params.append(device_id)
107 signature_query_clause += " AND target_device_id = ?"
108 signature_query_params.append(device_id)
109
110 signature_query_clause += " AND user_id = ?"
111 signature_query_params.append(user_id)
112209
113210 query_clauses.append(query_clause)
114 signature_query_clauses.append(signature_query_clause)
115211
116212 sql = (
117213 "SELECT user_id, device_id, "
118 " d.display_name AS device_display_name, "
214 " d.display_name, "
119215 " k.key_json"
120216 " FROM devices d"
121217 " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
126222 )
127223
128224 txn.execute(sql, query_params)
129 rows = self.db_pool.cursor_to_dict(txn)
130
131 result = {}
132 for row in rows:
225
226 result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
227 for (user_id, device_id, display_name, key_json) in txn:
133228 if include_deleted_devices:
134 deleted_devices.remove((row["user_id"], row["device_id"]))
135 result.setdefault(row["user_id"], {})[row["device_id"]] = row
229 deleted_devices.remove((user_id, device_id))
230 result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
231 display_name, db_to_json(key_json) if key_json else None
232 )
136233
137234 if include_deleted_devices:
138235 for user_id, device_id in deleted_devices:
139236 result.setdefault(user_id, {})[device_id] = None
140237
141 # get signatures on the device
142 signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
238 return result
239
240 def _get_e2e_cross_signing_signatures_for_devices_txn(
241 self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
242 ) -> List[Tuple[str, str, str, str]]:
243 """Get cross-signing signatures for a given list of devices
244
245 Returns signatures made by the owners of the devices.
246
247 Returns: a list of results; each entry in the list is a tuple of
248 (user_id, key_id, target_device_id, signature).
249 """
250 signature_query_clauses = []
251 signature_query_params = []
252
253 for (user_id, device_id) in device_query:
254 signature_query_clauses.append(
255 "target_user_id = ? AND target_device_id = ? AND user_id = ?"
256 )
257 signature_query_params.extend([user_id, device_id, user_id])
258
259 signature_sql = """
260 SELECT user_id, key_id, target_device_id, signature
261 FROM e2e_cross_signing_signatures WHERE %s
262 """ % (
143263 " OR ".join("(" + q + ")" for q in signature_query_clauses)
144264 )
145265
146266 txn.execute(signature_sql, signature_query_params)
147 rows = self.db_pool.cursor_to_dict(txn)
148
149 # add each cross-signing signature to the correct device in the result dict.
150 for row in rows:
151 signing_user_id = row["user_id"]
152 signing_key_id = row["key_id"]
153 target_user_id = row["target_user_id"]
154 target_device_id = row["target_device_id"]
155 signature = row["signature"]
156
157 target_user_result = result.get(target_user_id)
158 if not target_user_result:
159 continue
160
161 target_device_result = target_user_result.get(target_device_id)
162 if not target_device_result:
163 # note that target_device_result will be None for deleted devices.
164 continue
165
166 target_device_signatures = target_device_result.setdefault("signatures", {})
167 signing_user_signatures = target_device_signatures.setdefault(
168 signing_user_id, {}
169 )
170 signing_user_signatures[signing_key_id] = signature
171
172 log_kv(result)
173 return result
267 return txn.fetchall()
174268
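
The query assembly above emits one parameterised `(target_user_id = ? AND target_device_id = ? AND user_id = ?)` group per queried device, OR-ed together, which keeps each batch to a single round trip. Extracted as a standalone helper for illustration:

    def build_signature_query(device_query):
        clauses, params = [], []
        for user_id, device_id in device_query:
            clauses.append("target_user_id = ? AND target_device_id = ? AND user_id = ?")
            params.extend([user_id, device_id, user_id])
        sql = (
            "SELECT user_id, key_id, target_device_id, signature"
            " FROM e2e_cross_signing_signatures WHERE "
            + " OR ".join("(" + c + ")" for c in clauses)
        )
        return sql, params

    sql, params = build_signature_query([("@a:hs", "DEV1"), ("@b:hs", "DEV2")])
    print(sql)
    print(params)  # ['@a:hs', 'DEV1', '@a:hs', '@b:hs', 'DEV2', '@b:hs']
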
175269 async def get_e2e_one_time_keys(
176270 self, user_id: str, device_id: str, key_ids: List[str]
248342 )
249343
250344 @cached(max_entries=10000)
251 def count_e2e_one_time_keys(self, user_id, device_id):
345 async def count_e2e_one_time_keys(
346 self, user_id: str, device_id: str
347 ) -> Dict[str, int]:
252348 """ Count the number of one time keys the server has for a device
253349 Returns:
254 Dict mapping from algorithm to number of keys for that algorithm.
350 A mapping from algorithm to number of keys for that algorithm.
255351 """
256352
257353 def _count_e2e_one_time_keys(txn):
266362 result[algorithm] = key_count
267363 return result
268364
269 return self.db_pool.runInteraction(
365 return await self.db_pool.runInteraction(
270366 "count_e2e_one_time_keys", _count_e2e_one_time_keys
271367 )
272368
304400 list_name="user_ids",
305401 num_args=1,
306402 )
307 def _get_bare_e2e_cross_signing_keys_bulk(
403 async def _get_bare_e2e_cross_signing_keys_bulk(
308404 self, user_ids: List[str]
309405 ) -> Dict[str, Dict[str, dict]]:
310406 """Returns the cross-signing keys for a set of users. The output of this
312408 the signatures for the calling user need to be fetched.
313409
314410 Args:
315 user_ids (list[str]): the users whose keys are being requested
316
317 Returns:
318 dict[str, dict[str, dict]]: mapping from user ID to key type to key
319 data. If a user's cross-signing keys were not found, either
320 their user ID will not be in the dict, or their user ID will map
321 to None.
322
323 """
324 return self.db_pool.runInteraction(
411 user_ids: the users whose keys are being requested
412
413 Returns:
414 A mapping from user ID to key type to key data. If a user's cross-signing
415 keys were not found, either their user ID will not be in the dict, or
416 their user ID will map to None.
417
418 """
419 return await self.db_pool.runInteraction(
325420 "get_bare_e2e_cross_signing_keys_bulk",
326421 self._get_bare_e2e_cross_signing_keys_bulk_txn,
327422 user_ids,
537632 _get_all_user_signature_changes_for_remotes_txn,
538633 )
539634
635 @abc.abstractmethod
636 def get_device_stream_token(self) -> int:
637 """Get the current stream id from the _device_list_id_gen"""
638 ...
639
540640
541641 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
542 def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
642 async def set_e2e_device_keys(
643 self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
644 ) -> bool:
543645 """Stores device keys for a device. Returns whether there was a change
544646 or the keys were already in the database.
545647 """
575677 log_kv({"message": "Device keys stored."})
576678 return True
577679
578 return self.db_pool.runInteraction(
680 return await self.db_pool.runInteraction(
579681 "set_e2e_device_keys", _set_e2e_device_keys_txn
580682 )
581683
582 def claim_e2e_one_time_keys(self, query_list):
583 """Take a list of one time keys out of the database"""
684 async def claim_e2e_one_time_keys(
685 self, query_list: Iterable[Tuple[str, str, str]]
686 ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
687 """Take a list of one time keys out of the database.
688
689 Args:
690 query_list: An iterable of tuples of (user ID, device ID, algorithm).
691
692 Returns:
693 A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
694 """
584695
585696 @trace
586697 def _claim_e2e_one_time_keys(txn):
616727 )
617728 return result
618729
619 return self.db_pool.runInteraction(
730 return await self.db_pool.runInteraction(
620731 "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
621732 )
622733
623 def delete_e2e_keys_by_device(self, user_id, device_id):
734 async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
624735 def delete_e2e_keys_by_device_txn(txn):
625736 log_kv(
626737 {
643754 txn, self.count_e2e_one_time_keys, (user_id, device_id)
644755 )
645756
646 return self.db_pool.runInteraction(
757 await self.db_pool.runInteraction(
647758 "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
648759 )
649760
650 def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
761 def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
651762 """Set a user's cross-signing key.
652763
653764 Args:
657768 for a master key, 'self_signing' for a self-signing key, or
658769 'user_signing' for a user-signing key
659770 key (dict): the key data
771 stream_id (int)
660772 """
661773 # the 'key' dict will look something like:
662774 # {
694806 )
695807
696808 # and finally, store the key itself
697 with self._cross_signing_id_gen.get_next() as stream_id:
698 self.db_pool.simple_insert_txn(
699 txn,
700 "e2e_cross_signing_keys",
701 values={
702 "user_id": user_id,
703 "keytype": key_type,
704 "keydata": json_encoder.encode(key),
705 "stream_id": stream_id,
706 },
707 )
809 self.db_pool.simple_insert_txn(
810 txn,
811 "e2e_cross_signing_keys",
812 values={
813 "user_id": user_id,
814 "keytype": key_type,
815 "keydata": json_encoder.encode(key),
816 "stream_id": stream_id,
817 },
818 )
708819
709820 self._invalidate_cache_and_stream(
710821 txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
711822 )
712823
713 def set_e2e_cross_signing_key(self, user_id, key_type, key):
824 async def set_e2e_cross_signing_key(self, user_id, key_type, key):
714825 """Set a user's cross-signing key.
715826
716827 Args:
718829 key_type (str): the type of cross-signing key to set
719830 key (dict): the key data
720831 """
721 return self.db_pool.runInteraction(
722 "add_e2e_cross_signing_key",
723 self._set_e2e_cross_signing_key_txn,
724 user_id,
725 key_type,
726 key,
727 )
728
729 def store_e2e_cross_signing_signatures(self, user_id, signatures):
832
833 with await self._cross_signing_id_gen.get_next() as stream_id:
834 return await self.db_pool.runInteraction(
835 "add_e2e_cross_signing_key",
836 self._set_e2e_cross_signing_key_txn,
837 user_id,
838 key_type,
839 key,
840 stream_id,
841 )
842
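
The `set_e2e_cross_signing_key` change above follows from the same async conversion: the transaction function runs synchronously on a database thread and cannot await, so the stream ID is now allocated from the async generator first and handed to `_set_e2e_cross_signing_key_txn` as an argument. A runnable toy of that split (all names here are stand-ins):

    import asyncio

    async def allocate_stream_id() -> int:
        # stands in for `with await self._cross_signing_id_gen.get_next()`
        return 7

    def _set_key_txn(txn, user_id, key_type, stream_id):
        # runs synchronously on a DB thread: it cannot await, so the id
        # must arrive as a plain argument
        print("INSERT", user_id, key_type, "stream_id =", stream_id)

    async def run_interaction(desc, func, *args):
        return func("txn", *args)  # toy stand-in for db_pool.runInteraction

    async def set_e2e_cross_signing_key(user_id, key_type):
        stream_id = await allocate_stream_id()          # async side
        return await run_interaction(
            "add_e2e_cross_signing_key", _set_key_txn,  # sync side
            user_id, key_type, stream_id,
        )

    asyncio.run(set_e2e_cross_signing_key("@alice:hs", "master"))
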
843 async def store_e2e_cross_signing_signatures(
844 self, user_id: str, signatures: "Iterable[SignatureListItem]"
845 ) -> None:
730846 """Stores cross-signing signatures.
731847
732848 Args:
733 user_id (str): the user who made the signatures
734 signatures (iterable[SignatureListItem]): signatures to add
735 """
736 return self.db_pool.simple_insert_many(
849 user_id: the user who made the signatures
850 signatures: signatures to add
851 """
852 await self.db_pool.simple_insert_many(
737853 "e2e_cross_signing_signatures",
738854 [
739855 {
1414 import itertools
1515 import logging
1616 from queue import Empty, PriorityQueue
17 from typing import Dict, Iterable, List, Optional, Set, Tuple
17 from typing import Dict, Iterable, List, Set, Tuple
1818
1919 from synapse.api.errors import StoreError
20 from synapse.events import EventBase
2021 from synapse.metrics.background_process_metrics import run_as_background_process
2122 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
22 from synapse.storage.database import DatabasePool
23 from synapse.storage.database import DatabasePool, LoggingTransaction
2324 from synapse.storage.databases.main.events_worker import EventsWorkerStore
2425 from synapse.storage.databases.main.signatures import SignatureWorkerStore
26 from synapse.types import Collection
2527 from synapse.util.caches.descriptors import cached
2628 from synapse.util.iterutils import batch_iter
2729
2931
3032
3133 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
32 def get_auth_chain(self, event_ids, include_given=False):
34 async def get_auth_chain(
35 self, event_ids: Collection[str], include_given: bool = False
36 ) -> List[EventBase]:
3337 """Get auth events for given event_ids. The events *must* be state events.
3438
3539 Args:
36 event_ids (list): state events
37 include_given (bool): include the given events in result
40 event_ids: state events
41 include_given: include the given events in result
3842
3943 Returns:
4044 list of events
4145 """
42 return self.get_auth_chain_ids(
46 event_ids = await self.get_auth_chain_ids(
4347 event_ids, include_given=include_given
44 ).addCallback(self.get_events_as_list)
45
46 def get_auth_chain_ids(
47 self,
48 event_ids: List[str],
49 include_given: bool = False,
50 ignore_events: Optional[Set[str]] = None,
51 ):
48 )
49 return await self.get_events_as_list(event_ids)
50
51 async def get_auth_chain_ids(
52 self, event_ids: Collection[str], include_given: bool = False,
53 ) -> List[str]:
5254 """Get auth events for given event_ids. The events *must* be state events.
5355
5456 Args:
5557 event_ids: state events
5658 include_given: include the given events in result
57 ignore_events: Set of events to exclude from the returned auth
58 chain. This is useful if the caller will just discard the
59 given events anyway, and saves us from figuring out their auth
60 chains if not required.
6159
6260 Returns:
63 list of event_ids
64 """
65 return self.db_pool.runInteraction(
61 An awaitable which resolves to a list of event_ids
62 """
63 return await self.db_pool.runInteraction(
6664 "get_auth_chain_ids",
6765 self._get_auth_chain_ids_txn,
6866 event_ids,
6967 include_given,
70 ignore_events,
71 )
72
73 def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
74 if ignore_events is None:
75 ignore_events = set()
76
68 )
69
70 def _get_auth_chain_ids_txn(
71 self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
72 ) -> List[str]:
7773 if include_given:
7874 results = set(event_ids)
7975 else:
8076 results = set()
8177
82 base_sql = "SELECT auth_id FROM event_auth WHERE "
78 base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
8379
8480 front = set(event_ids)
8581 while front:
9187 txn.execute(base_sql + clause, args)
9288 new_front.update(r[0] for r in txn)
9389
94 new_front -= ignore_events
9590 new_front -= results
9691
9792 front = new_front
9994
10095 return list(results)
10196
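
`_get_auth_chain_ids_txn` is a breadth-first walk over `event_auth`: expand the frontier, drop anything already expanded, and repeat until nothing new appears. This diff removes the now-unused `ignore_events` filter and adds `DISTINCT` to the query. The same walk over an in-memory map (the `results.update(front)` step sits in the folded context above):

    def get_auth_chain_ids(auth_map, event_ids, include_given=False):
        results = set(event_ids) if include_given else set()
        front = set(event_ids)
        while front:
            new_front = set()
            for event_id in front:
                # SELECT DISTINCT auth_id FROM event_auth WHERE event_id = ?
                new_front.update(auth_map.get(event_id, ()))
            new_front -= results  # drop ids we have already expanded
            front = new_front
            results.update(front)
        return list(results)

    auth_map = {"$member": ["$create", "$power"], "$power": ["$create"]}
    print(sorted(get_auth_chain_ids(auth_map, ["$member"])))  # ['$create', '$power']
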
102 def get_auth_chain_difference(self, state_sets: List[Set[str]]):
97 async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
10398 """Given sets of state events figure out the auth chain difference (as
10499 per state res v2 algorithm).
105100
108103 chain.
109104
110105 Returns:
111 Deferred[Set[str]]
112 """
113
114 return self.db_pool.runInteraction(
106 The set of the difference in auth chains.
107 """
108
109 return await self.db_pool.runInteraction(
115110 "get_auth_chain_difference",
116111 self._get_auth_chain_difference_txn,
117112 state_sets,
256251 # Return all events where not all sets can reach them.
257252 return {eid for eid, n in event_to_missing_sets.items() if n}
258253
259 def get_oldest_events_in_room(self, room_id):
260 return self.db_pool.runInteraction(
261 "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
262 )
263
264 def get_oldest_events_with_depth_in_room(self, room_id):
265 return self.db_pool.runInteraction(
254 async def get_oldest_events_with_depth_in_room(self, room_id):
255 return await self.db_pool.runInteraction(
266256 "get_oldest_events_with_depth_in_room",
267257 self.get_oldest_events_with_depth_in_room_txn,
268258 room_id,
302292 else:
303293 return max(row["depth"] for row in rows)
304294
305 def _get_oldest_events_in_room_txn(self, txn, room_id):
306 return self.db_pool.simple_select_onecol_txn(
307 txn,
308 table="event_backward_extremities",
309 keyvalues={"room_id": room_id},
310 retcol="event_id",
311 )
312
313 def get_prev_events_for_room(self, room_id: str):
295 async def get_prev_events_for_room(self, room_id: str) -> List[str]:
314296 """
315297 Gets a subset of the current forward extremities in the given room.
316298
318300 events which refer to hundreds of prev_events.
319301
320302 Args:
321 room_id (str): room_id
303 room_id: room_id
322304
323305 Returns:
324 Deferred[List[str]]: the event ids of the forward extremites
325
326 """
327
328 return self.db_pool.runInteraction(
306 The event ids of the forward extremities.
307
308 """
309
310 return await self.db_pool.runInteraction(
329311 "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
330312 )
331313
345327
346328 return [row[0] for row in txn]
347329
348 def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
330 async def get_rooms_with_many_extremities(
331 self, min_count: int, limit: int, room_id_filter: Iterable[str]
332 ) -> List[str]:
349333 """Get the top rooms with at least N extremities.
350334
351335 Args:
352 min_count (int): The minimum number of extremities
353 limit (int): The maximum number of rooms to return.
354 room_id_filter (iterable[str]): room_ids to exclude from the results
336 min_count: The minimum number of extremities
337 limit: The maximum number of rooms to return.
338 room_id_filter: room_ids to exclude from the results
355339
356340 Returns:
357 Deferred[list]: At most `limit` room IDs that have at least
358 `min_count` extremities, sorted by extremity count.
341 At most `limit` room IDs that have at least `min_count` extremities,
342 sorted by extremity count.
359343 """
360344
361345 def _get_rooms_with_many_extremities_txn(txn):
380364 txn.execute(sql, query_args)
381365 return [room_id for room_id, in txn]
382366
383 return self.db_pool.runInteraction(
367 return await self.db_pool.runInteraction(
384368 "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
385369 )
386370
387371 @cached(max_entries=5000, iterable=True)
388 def get_latest_event_ids_in_room(self, room_id):
389 return self.db_pool.simple_select_onecol(
372 async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
373 return await self.db_pool.simple_select_onecol(
390374 table="event_forward_extremities",
391375 keyvalues={"room_id": room_id},
392376 retcol="event_id",
393377 desc="get_latest_event_ids_in_room",
394378 )
395379
396 def get_min_depth(self, room_id):
397 """ For hte given room, get the minimum depth we have seen for it.
398 """
399 return self.db_pool.runInteraction(
380 async def get_min_depth(self, room_id: str) -> int:
381 """For the given room, get the minimum depth we have seen for it.
382 """
383 return await self.db_pool.runInteraction(
400384 "get_min_depth", self._get_min_depth_interaction, room_id
401385 )
402386
411395
412396 return int(min_depth) if min_depth is not None else None
413397
414 def get_forward_extremeties_for_room(self, room_id, stream_ordering):
398 async def get_forward_extremeties_for_room(
399 self, room_id: str, stream_ordering: int
400 ) -> List[str]:
415401 """For a given room_id and stream_ordering, return the forward
416402 extremeties of the room at that point in "time".
417403
419405 stream_orderings from that point.
420406
421407 Args:
422 room_id (str):
423 stream_ordering (int):
408 room_id:
409 stream_ordering:
424410
425411 Returns:
426 deferred, which resolves to a list of event_ids
412 A list of event_ids
427413 """
428414 # We want to make the cache more effective, so we clamp to the last
429415 # change before the given ordering.
439425 if last_change > self.stream_ordering_month_ago:
440426 stream_ordering = min(last_change, stream_ordering)
441427
442 return self._get_forward_extremeties_for_room(room_id, stream_ordering)
428 return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
443429
444430 @cached(max_entries=5000, num_args=2)
445 def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
431 async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
446432 """For a given room_id and stream_ordering, return the forward
447433 extremeties of the room at that point in "time".
448434
467453 txn.execute(sql, (stream_ordering, room_id))
468454 return [event_id for event_id, in txn]
469455
470 return self.db_pool.runInteraction(
456 return await self.db_pool.runInteraction(
471457 "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
472458 )
473459
474 def get_backfill_events(self, room_id, event_list, limit):
460 async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
475461 """Get a list of Events for a given topic that occurred before (and
476462 including) the events in event_list. Return a list of max size `limit`
477463
478464 Args:
479 txn
480 room_id (str)
481 event_list (list)
482 limit (int)
483 """
484 return (
485 self.db_pool.runInteraction(
486 "get_backfill_events",
487 self._get_backfill_events,
488 room_id,
489 event_list,
490 limit,
491 )
492 .addCallback(self.get_events_as_list)
493 .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
494 )
465 room_id
466 event_list
467 limit
468 """
469 event_ids = await self.db_pool.runInteraction(
470 "get_backfill_events",
471 self._get_backfill_events,
472 room_id,
473 event_list,
474 limit,
475 )
476 events = await self.get_events_as_list(event_ids)
477 return sorted(events, key=lambda e: -e.depth)
495478
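
This `get_backfill_events` hunk is representative of the diff as a whole: a Twisted Deferred pipeline built from `.addCallback(...)` becomes straight-line `await`s with an explicit sort at the end. As a runnable toy:

    import asyncio

    class Event:
        def __init__(self, event_id: str, depth: int) -> None:
            self.event_id, self.depth = event_id, depth

    async def get_backfill_event_ids() -> list:
        return ["$shallow", "$deep"]  # stands in for the runInteraction call

    async def get_events_as_list(ids: list) -> list:
        depths = {"$shallow": 1, "$deep": 2}
        return [Event(i, depths[i]) for i in ids]

    async def get_backfill_events() -> list:
        event_ids = await get_backfill_event_ids()
        events = await get_events_as_list(event_ids)
        return sorted(events, key=lambda e: -e.depth)  # deepest first

    print([e.event_id for e in asyncio.run(get_backfill_events())])
    # ['$deep', '$shallow']
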
496479 def _get_backfill_events(self, txn, room_id, event_list, limit):
497480 logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
552535 latest_events,
553536 limit,
554537 )
555 events = await self.get_events_as_list(ids)
556 return events
538 return await self.get_events_as_list(ids)
557539
558540 def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
559541
651633 _delete_old_forward_extrem_cache_txn,
652634 )
653635
654 def clean_room_for_join(self, room_id):
655 return self.db_pool.runInteraction(
636 async def clean_room_for_join(self, room_id):
637 return await self.db_pool.runInteraction(
656638 "clean_room_for_join", self._clean_room_for_join_txn, room_id
657639 )
658640
1414 # limitations under the License.
1515
1616 import logging
17 from typing import List
17 from typing import Dict, List, Optional, Tuple, Union
18
19 import attr
1820
1921 from synapse.metrics.background_process_metrics import run_as_background_process
2022 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
2123 from synapse.storage.database import DatabasePool
2224 from synapse.util import json_encoder
23 from synapse.util.caches.descriptors import cachedInlineCallbacks
25 from synapse.util.caches.descriptors import cached
2426
2527 logger = logging.getLogger(__name__)
2628
8587 self._rotate_delay = 3
8688 self._rotate_count = 10000
8789
88 @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
89 def get_unread_event_push_actions_by_room_for_user(
90 self, room_id, user_id, last_read_event_id
91 ):
92 ret = yield self.db_pool.runInteraction(
90 @cached(num_args=3, tree=True, max_entries=5000)
91 async def get_unread_event_push_actions_by_room_for_user(
92 self, room_id: str, user_id: str, last_read_event_id: Optional[str],
93 ) -> Dict[str, int]:
94 """Get the notification count, the highlight count and the unread message count
95 for a given user in a given room after the given read receipt.
96
97 Note that this function assumes the user to be a current member of the room,
98 since it's either called by the sync handler to handle joined room entries, or by
99 the HTTP pusher to calculate the badge of unread joined rooms.
100
101 Args:
102 room_id: The room to retrieve the counts in.
103 user_id: The user to retrieve the counts for.
104 last_read_event_id: The event associated with the latest read receipt for
105 this user in this room. None if no receipt for this user in this room.
106
107 Returns:
108 A dict containing the counts mentioned earlier in this docstring,
109 respectively under the keys "notify_count", "highlight_count" and
110 "unread_count".
111 """
112 return await self.db_pool.runInteraction(
93113 "get_unread_event_push_actions_by_room",
94114 self._get_unread_counts_by_receipt_txn,
95115 room_id,
96116 user_id,
97117 last_read_event_id,
98118 )
99 return ret
100119
101120 def _get_unread_counts_by_receipt_txn(
102 self, txn, room_id, user_id, last_read_event_id
121 self, txn, room_id, user_id, last_read_event_id,
103122 ):
104 sql = (
105 "SELECT stream_ordering"
106 " FROM events"
107 " WHERE room_id = ? AND event_id = ?"
108 )
109 txn.execute(sql, (room_id, last_read_event_id))
110 results = txn.fetchall()
111 if len(results) == 0:
112 return {"notify_count": 0, "highlight_count": 0}
113
114 stream_ordering = results[0][0]
123 stream_ordering = None
124
125 if last_read_event_id is not None:
126 stream_ordering = self.get_stream_id_for_event_txn(
127 txn, last_read_event_id, allow_none=True,
128 )
129
130 if stream_ordering is None:
131 # Either last_read_event_id is None, or it's an event we don't have (e.g.
132 # because it's been purged), in which case retrieve the stream ordering for
133 # the latest membership event from this user in this room (which we assume is
134 # a join).
135 event_id = self.db_pool.simple_select_one_onecol_txn(
136 txn=txn,
137 table="local_current_membership",
138 keyvalues={"room_id": room_id, "user_id": user_id},
139 retcol="event_id",
140 )
141
142 stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
115143
116144 return self._get_unread_counts_by_pos_txn(
117145 txn, room_id, user_id, stream_ordering
118146 )
119147
120148 def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
121
122 # First get number of notifications.
123 # We don't need to put a notif=1 clause as all rows always have
124 # notif=1
125149 sql = (
126 "SELECT count(*)"
150 "SELECT"
151 " COUNT(CASE WHEN notif = 1 THEN 1 END),"
152 " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
153 " COUNT(CASE WHEN unread = 1 THEN 1 END)"
127154 " FROM event_push_actions ea"
128 " WHERE"
129 " user_id = ?"
130 " AND room_id = ?"
131 " AND stream_ordering > ?"
155 " WHERE user_id = ?"
156 " AND room_id = ?"
157 " AND stream_ordering > ?"
132158 )
133159
134160 txn.execute(sql, (user_id, room_id, stream_ordering))
135161 row = txn.fetchone()
136 notify_count = row[0] if row else 0
162
163 (notif_count, highlight_count, unread_count) = (0, 0, 0)
164
165 if row:
166 (notif_count, highlight_count, unread_count) = row
137167
138168 txn.execute(
139169 """
140 SELECT notif_count FROM event_push_summary
141 WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
142 """,
170 SELECT notif_count, unread_count FROM event_push_summary
171 WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
172 """,
143173 (room_id, user_id, stream_ordering),
144174 )
145 rows = txn.fetchall()
146 if rows:
147 notify_count += rows[0][0]
148
149 # Now get the number of highlights
150 sql = (
151 "SELECT count(*)"
152 " FROM event_push_actions ea"
153 " WHERE"
154 " highlight = 1"
155 " AND user_id = ?"
156 " AND room_id = ?"
157 " AND stream_ordering > ?"
158 )
159
160 txn.execute(sql, (user_id, room_id, stream_ordering))
161175 row = txn.fetchone()
162 highlight_count = row[0] if row else 0
163
164 return {"notify_count": notify_count, "highlight_count": highlight_count}
176
177 if row:
178 notif_count += row[0]
179
180 if row[1] is not None:
181 # The unread_count column of event_push_summary is NULLable, so we need
182 # to make sure we don't try increasing the unread counts if it's NULL
183 # for this row.
184 unread_count += row[1]
185
186 return {
187 "notify_count": notif_count,
188 "unread_count": unread_count,
189 "highlight_count": highlight_count,
190 }
165191
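
The rewritten query replaces three separate scans (notification count, summary lookup, highlight count) with a single pass using conditional aggregation: `COUNT(CASE WHEN col = 1 THEN 1 END)` counts only the rows where that flag is set. The behaviour is easy to verify with sqlite3:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_push_actions"
        " (user_id TEXT, room_id TEXT, stream_ordering INT,"
        "  notif INT, highlight INT, unread INT)"
    )
    conn.executemany(
        "INSERT INTO event_push_actions VALUES (?, ?, ?, ?, ?, ?)",
        [
            ("@u:hs", "!r:hs", 10, 1, 0, 1),
            ("@u:hs", "!r:hs", 11, 1, 1, 1),
            ("@u:hs", "!r:hs", 12, 0, 0, 1),
        ],
    )
    row = conn.execute(
        "SELECT"
        " COUNT(CASE WHEN notif = 1 THEN 1 END),"
        " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
        " COUNT(CASE WHEN unread = 1 THEN 1 END)"
        " FROM event_push_actions"
        " WHERE user_id = ? AND room_id = ? AND stream_ordering > ?",
        ("@u:hs", "!r:hs", 9),
    ).fetchone()
    print(row)  # (2, 1, 3)
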
166192 async def get_push_action_users_in_range(
167193 self, min_stream_ordering, max_stream_ordering
169195 def f(txn):
170196 sql = (
171197 "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
172 " stream_ordering >= ? AND stream_ordering <= ?"
198 " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
173199 )
174200 txn.execute(sql, (min_stream_ordering, max_stream_ordering))
175201 return [r[0] for r in txn]
222248 " AND ep.user_id = ?"
223249 " AND ep.stream_ordering > ?"
224250 " AND ep.stream_ordering <= ?"
251 " AND ep.notif = 1"
225252 " ORDER BY ep.stream_ordering ASC LIMIT ?"
226253 )
227254 args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
250277 " AND ep.user_id = ?"
251278 " AND ep.stream_ordering > ?"
252279 " AND ep.stream_ordering <= ?"
280 " AND ep.notif = 1"
253281 " ORDER BY ep.stream_ordering ASC LIMIT ?"
254282 )
255283 args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
324352 " AND ep.user_id = ?"
325353 " AND ep.stream_ordering > ?"
326354 " AND ep.stream_ordering <= ?"
355 " AND ep.notif = 1"
327356 " ORDER BY ep.stream_ordering DESC LIMIT ?"
328357 )
329358 args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
352381 " AND ep.user_id = ?"
353382 " AND ep.stream_ordering > ?"
354383 " AND ep.stream_ordering <= ?"
384 " AND ep.notif = 1"
355385 " ORDER BY ep.stream_ordering DESC LIMIT ?"
356386 )
357387 args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
383413 # Now return the first `limit`
384414 return notifs[:limit]
385415
386 def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
416 async def get_if_maybe_push_in_range_for_user(
417 self, user_id: str, min_stream_ordering: int
418 ) -> bool:
387419 """A fast check to see if there might be something to push for the
388420 user since the given stream ordering. May return false positives.
389421
390422 Useful to know whether to bother starting a pusher on start up or not.
391423
392424 Args:
393 user_id (str)
394 min_stream_ordering (int)
425 user_id
426 min_stream_ordering
395427
396428 Returns:
397 Deferred[bool]: True if there may be push to process, False if
398 there definitely isn't.
429 True if there may be push to process, False if there definitely isn't.
399430 """
400431
401432 def _get_if_maybe_push_in_range_for_user_txn(txn):
402433 sql = """
403434 SELECT 1 FROM event_push_actions
404 WHERE user_id = ? AND stream_ordering > ?
435 WHERE user_id = ? AND stream_ordering > ? AND notif = 1
405436 LIMIT 1
406437 """
407438
408439 txn.execute(sql, (user_id, min_stream_ordering))
409440 return bool(txn.fetchone())
410441
411 return self.db_pool.runInteraction(
442 return await self.db_pool.runInteraction(
412443 "get_if_maybe_push_in_range_for_user",
413444 _get_if_maybe_push_in_range_for_user_txn,
414445 )
415446
416 async def add_push_actions_to_staging(self, event_id, user_id_actions):
447 async def add_push_actions_to_staging(
448 self,
449 event_id: str,
450 user_id_actions: Dict[str, List[Union[dict, str]]],
451 count_as_unread: bool,
452 ) -> None:
417453 """Add the push actions for the event to the push action staging area.
418454
419455 Args:
420 event_id (str)
421 user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
422 user_id to list of push actions, where an action can either be
423 a string or dict.
424
425 Returns:
426 Deferred
427 """
428
456 event_id
457 user_id_actions: A mapping of user_id to list of push actions, where
458 an action can either be a string or dict.
459 count_as_unread: Whether this event should increment unread counts.
460 """
429461 if not user_id_actions:
430462 return
431463
432464 # This is a helper function for generating the necessary tuple that
433 # can be used to inert into the `event_push_actions_staging` table.
465 # can be used to insert into the `event_push_actions_staging` table.
434466 def _gen_entry(user_id, actions):
435467 is_highlight = 1 if _action_has_highlight(actions) else 0
468 notif = 1 if "notify" in actions else 0
436469 return (
437470 event_id, # event_id column
438471 user_id, # user_id column
439472 _serialize_action(actions, is_highlight), # actions column
440 1, # notif column
473 notif, # notif column
441474 is_highlight, # highlight column
475 int(count_as_unread), # unread column
442476 )
443477
444478 def _add_push_actions_to_staging_txn(txn):
447481
448482 sql = """
449483 INSERT INTO event_push_actions_staging
450 (event_id, user_id, actions, notif, highlight)
451 VALUES (?, ?, ?, ?, ?)
484 (event_id, user_id, actions, notif, highlight, unread)
485 VALUES (?, ?, ?, ?, ?, ?)
452486 """
453487
454488 txn.executemany(
507541 "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
508542 )
509543
510 def find_first_stream_ordering_after_ts(self, ts):
544 async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
511545 """Gets the stream ordering corresponding to a given timestamp.
512546
513547 Specifically, finds the stream_ordering of the first event that was
516550 relatively slow.
517551
518552 Args:
519 ts (int): timestamp in millis
553 ts: timestamp in millis
520554
521555 Returns:
522 Deferred[int]: stream ordering of the first event received on/after
523 the timestamp
524 """
525 return self.db_pool.runInteraction(
556 stream ordering of the first event received on/after the timestamp
557 """
558 return await self.db_pool.runInteraction(
526559 "_find_first_stream_ordering_after_ts_txn",
527560 self._find_first_stream_ordering_after_ts_txn,
528561 ts,
610643 "SELECT e.received_ts"
611644 " FROM event_push_actions AS ep"
612645 " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
613 " WHERE ep.stream_ordering > ?"
646 " WHERE ep.stream_ordering > ? AND notif = 1"
614647 " ORDER BY ep.stream_ordering ASC"
615648 " LIMIT 1"
616649 )
674707 " FROM event_push_actions epa, events e"
675708 " WHERE epa.event_id = e.event_id"
676709 " AND epa.user_id = ? %s"
710 " AND epa.notif = 1"
677711 " ORDER BY epa.stream_ordering DESC"
678712 " LIMIT ?" % (before_clause,)
679713 )
813847 # Calculate the new counts that should be upserted into event_push_summary
814848 sql = """
815849 SELECT user_id, room_id,
816 coalesce(old.notif_count, 0) + upd.notif_count,
850 coalesce(old.%s, 0) + upd.cnt,
817851 upd.stream_ordering,
818852 old.user_id
819853 FROM (
820 SELECT user_id, room_id, count(*) as notif_count,
854 SELECT user_id, room_id, count(*) as cnt,
821855 max(stream_ordering) as stream_ordering
822856 FROM event_push_actions
823857 WHERE ? <= stream_ordering AND stream_ordering < ?
824858 AND highlight = 0
859 AND %s = 1
825860 GROUP BY user_id, room_id
826861 ) AS upd
827862 LEFT JOIN event_push_summary AS old USING (user_id, room_id)
828863 """
829864
830 txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
831 rows = txn.fetchall()
832
833 logger.info("Rotating notifications, handling %d rows", len(rows))
865 # First get the count of unread messages.
866 txn.execute(
867 sql % ("unread_count", "unread"),
868 (old_rotate_stream_ordering, rotate_to_stream_ordering),
869 )
870
871 # We need to merge results from the two requests (the one that retrieves the
872 # unread count and the one that retrieves the notifications count) into a single
873 # object because we might not have the same number of rows in each of them. To do
874 # this, we use a dict indexed on the user ID and room ID to make it easier to
875 # populate.
876 summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
877 for row in txn:
878 summaries[(row[0], row[1])] = _EventPushSummary(
879 unread_count=row[2],
880 stream_ordering=row[3],
881 old_user_id=row[4],
882 notif_count=0,
883 )
884
885 # Then get the count of notifications.
886 txn.execute(
887 sql % ("notif_count", "notif"),
888 (old_rotate_stream_ordering, rotate_to_stream_ordering),
889 )
890
891 for row in txn:
892 if (row[0], row[1]) in summaries:
893 summaries[(row[0], row[1])].notif_count = row[2]
894 else:
895 # Because the rules on notifying are different from the rules on marking
896 # a message unread, we might end up with messages that notify but aren't
897 # marked unread, so we might not have a summary for this (user, room)
898 # tuple to complete.
899 summaries[(row[0], row[1])] = _EventPushSummary(
900 unread_count=0,
901 stream_ordering=row[3],
902 old_user_id=row[4],
903 notif_count=row[2],
904 )
905
906 logger.info("Rotating notifications, handling %d rows", len(summaries))
834907
835908 # If the `old.user_id` above is NULL then we know there isn't already an
836909 # entry in the table, so we simply insert it. Otherwise we update the
840913 table="event_push_summary",
841914 values=[
842915 {
843 "user_id": row[0],
844 "room_id": row[1],
845 "notif_count": row[2],
846 "stream_ordering": row[3],
916 "user_id": user_id,
917 "room_id": room_id,
918 "notif_count": summary.notif_count,
919 "unread_count": summary.unread_count,
920 "stream_ordering": summary.stream_ordering,
847921 }
848 for row in rows
849 if row[4] is None
922 for ((user_id, room_id), summary) in summaries.items()
923 if summary.old_user_id is None
850924 ],
851925 )
852926
853927 txn.executemany(
854928 """
855 UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
929 UPDATE event_push_summary
930 SET notif_count = ?, unread_count = ?, stream_ordering = ?
856931 WHERE user_id = ? AND room_id = ?
857932 """,
858 ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
933 (
934 (
935 summary.notif_count,
936 summary.unread_count,
937 summary.stream_ordering,
938 user_id,
939 room_id,
940 )
941 for ((user_id, room_id), summary) in summaries.items()
942 if summary.old_user_id is not None
943 ),
859944 )
860945
861946 txn.execute(
881966 pass
882967
883968 return False
969
970
971 @attr.s
972 class _EventPushSummary:
973 """Summary of pending event push actions for a given user in a given room.
974 Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
975 """
976
977 unread_count = attr.ib(type=int)
978 stream_ordering = attr.ib(type=int)
979 old_user_id = attr.ib(type=str)
980 notif_count = attr.ib(type=int)
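The two-pass merge above can be reduced to a self-contained sketch, simplified to plain `(user_id, room_id, count)` tuples with a namedtuple standing in for `_EventPushSummary`:

```python
from collections import namedtuple

Summary = namedtuple("Summary", ["unread_count", "notif_count"])

def merge_counts(unread_rows, notif_rows):
    # First pass: seed the dict from the unread-count query.
    out = {}
    for user_id, room_id, cnt in unread_rows:
        out[(user_id, room_id)] = Summary(unread_count=cnt, notif_count=0)
    # Second pass: fold in notification counts, creating entries for
    # (user, room) pairs that notified without counting as unread.
    for user_id, room_id, cnt in notif_rows:
        prev = out.get((user_id, room_id), Summary(0, 0))
        out[(user_id, room_id)] = prev._replace(notif_count=cnt)
    return out
```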
1616 import itertools
1717 import logging
1818 from collections import OrderedDict, namedtuple
19 from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
19 from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
2020
2121 import attr
2222 from prometheus_client import Counter
23
24 from twisted.internet import defer
2523
2624 import synapse.metrics
2725 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
112110 hs.config.worker.writers.events == hs.get_instance_name()
113111 ), "Can only instantiate EventsStore on master"
114112
115 @defer.inlineCallbacks
116 def _persist_events_and_state_updates(
113 async def _persist_events_and_state_updates(
117114 self,
118115 events_and_contexts: List[Tuple[EventBase, EventContext]],
119116 current_state_for_room: Dict[str, StateMap[str]],
120117 state_delta_for_room: Dict[str, DeltaState],
121118 new_forward_extremeties: Dict[str, List[str]],
122119 backfilled: bool = False,
123 ):
120 ) -> None:
124121 """Persist a set of events alongside updates to the current state and
125122 forward extremities tables.
126123
135132 backfilled
136133
137134 Returns:
138 Deferred: resolves when the events have been persisted
135 Resolves when the events have been persisted
139136 """
140137
141138 # We want to calculate the stream orderings as late as possible, as
155152 # Note: Multiple instances of this function cannot be in flight at
156153 # the same time for the same room.
157154 if backfilled:
158 stream_ordering_manager = self._backfill_id_gen.get_next_mult(
155 stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
159156 len(events_and_contexts)
160157 )
161158 else:
162 stream_ordering_manager = self._stream_id_gen.get_next_mult(
159 stream_ordering_manager = await self._stream_id_gen.get_next_mult(
163160 len(events_and_contexts)
164161 )
165162
167164 for (event, context), stream in zip(events_and_contexts, stream_orderings):
168165 event.internal_metadata.stream_ordering = stream
169166
170 yield self.db_pool.runInteraction(
167 await self.db_pool.runInteraction(
171168 "persist_events",
172169 self._persist_events_txn,
173170 events_and_contexts=events_and_contexts,
205202 (room_id,), list(latest_event_ids)
206203 )
207204
208 @defer.inlineCallbacks
209 def _get_events_which_are_prevs(self, event_ids):
205 async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
210206 """Filter the supplied list of event_ids to get those which are prev_events of
211207 existing (non-outlier/rejected) events.
212208
213209 Args:
214 event_ids (Iterable[str]): event ids to filter
210 event_ids: event ids to filter
215211
216212 Returns:
217 Deferred[List[str]]: filtered event ids
213 Filtered event ids
218214 """
219215 results = []
220216
239235 results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
240236
241237 for chunk in batch_iter(event_ids, 100):
242 yield self.db_pool.runInteraction(
238 await self.db_pool.runInteraction(
243239 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
244240 )
245241
246242 return results
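`batch_iter` keeps each transaction to at most 100 ids; from memory it behaves roughly like the sketch below (treat this as an approximation of `synapse.util.iterutils.batch_iter`, not its exact source):

```python
from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Yield successive tuples of up to `size` items until exhausted.
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, size))
        if not batch:
            return
        yield batch
```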
247243
248 @defer.inlineCallbacks
249 def _get_prevs_before_rejected(self, event_ids):
244 async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
250245 """Get soft-failed ancestors to remove from the extremities.
251246
252247 Given a set of events, find all those that have been soft-failed or
258253 are separated by soft failed events.
259254
260255 Args:
261 event_ids (Iterable[str]): Events to find prev events for. Note
262 that these must have already been persisted.
256 event_ids: Events to find prev events for. Note that these must have
257 already been persisted.
263258
264259 Returns:
265 Deferred[set[str]]
260 The previous events.
266261 """
267262
268263 # The set of event_ids to return. This includes all soft-failed events
303298 existing_prevs.add(prev_event_id)
304299
305300 for chunk in batch_iter(event_ids, 100):
306 yield self.db_pool.runInteraction(
301 await self.db_pool.runInteraction(
307302 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
308303 )
309304
13001295 sql = """
13011296 INSERT INTO event_push_actions (
13021297 room_id, event_id, user_id, actions, stream_ordering,
1303 topological_ordering, notif, highlight
1304 )
1305 SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
1298 topological_ordering, notif, highlight, unread
1299 )
1300 SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
13061301 FROM event_push_actions_staging
13071302 WHERE event_id = ?
13081303 """
1414
1515 import logging
1616
17 from twisted.internet import defer
18
1917 from synapse.api.constants import EventContentFields
2018 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
2119 from synapse.storage.database import DatabasePool
9391 where_clause="NOT have_censored",
9492 )
9593
96 @defer.inlineCallbacks
97 def _background_reindex_fields_sender(self, progress, batch_size):
94 async def _background_reindex_fields_sender(self, progress, batch_size):
9895 target_min_stream_id = progress["target_min_stream_id_inclusive"]
9996 max_stream_id = progress["max_stream_id_exclusive"]
10097 rows_inserted = progress.get("rows_inserted", 0)
154151
155152 return len(rows)
156153
157 result = yield self.db_pool.runInteraction(
154 result = await self.db_pool.runInteraction(
158155 self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
159156 )
160157
161158 if not result:
162 yield self.db_pool.updates._end_background_update(
159 await self.db_pool.updates._end_background_update(
163160 self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
164161 )
165162
166163 return result
167164
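The reindex and cleanup handlers in this file all follow the same background-update contract, which these conversions preserve; schematically (the update name and txn function here are placeholders, not real Synapse updates):

```python
async def _my_background_update(self, progress, batch_size):
    # Process up to batch_size rows and report how many were handled.
    handled = await self.db_pool.runInteraction(
        "my_update", _my_update_txn, progress, batch_size  # placeholder txn
    )
    # Once a run handles nothing, deregister the update so it stops running.
    if not handled:
        await self.db_pool.updates._end_background_update("my_update")
    return handled
```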
168 @defer.inlineCallbacks
169 def _background_reindex_origin_server_ts(self, progress, batch_size):
165 async def _background_reindex_origin_server_ts(self, progress, batch_size):
170166 target_min_stream_id = progress["target_min_stream_id_inclusive"]
171167 max_stream_id = progress["max_stream_id_exclusive"]
172168 rows_inserted = progress.get("rows_inserted", 0)
233229
234230 return len(rows_to_update)
235231
236 result = yield self.db_pool.runInteraction(
232 result = await self.db_pool.runInteraction(
237233 self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
238234 )
239235
240236 if not result:
241 yield self.db_pool.updates._end_background_update(
237 await self.db_pool.updates._end_background_update(
242238 self.EVENT_ORIGIN_SERVER_TS_NAME
243239 )
244240
245241 return result
246242
247 @defer.inlineCallbacks
248 def _cleanup_extremities_bg_update(self, progress, batch_size):
243 async def _cleanup_extremities_bg_update(self, progress, batch_size):
249244 """Background update to clean out extremities that should have been
250245 deleted previously.
251246
413408
414409 return len(original_set)
415410
416 num_handled = yield self.db_pool.runInteraction(
411 num_handled = await self.db_pool.runInteraction(
417412 "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
418413 )
419414
420415 if not num_handled:
421 yield self.db_pool.updates._end_background_update(
416 await self.db_pool.updates._end_background_update(
422417 self.DELETE_SOFT_FAILED_EXTREMITIES
423418 )
424419
425420 def _drop_table_txn(txn):
426421 txn.execute("DROP TABLE _extremities_to_check")
427422
428 yield self.db_pool.runInteraction(
423 await self.db_pool.runInteraction(
429424 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
430425 )
431426
432427 return num_handled
433428
434 @defer.inlineCallbacks
435 def _redactions_received_ts(self, progress, batch_size):
429 async def _redactions_received_ts(self, progress, batch_size):
436430 """Handles filling out the `received_ts` column in redactions.
437431 """
438432 last_event_id = progress.get("last_event_id", "")
479473
480474 return len(rows)
481475
482 count = yield self.db_pool.runInteraction(
476 count = await self.db_pool.runInteraction(
483477 "_redactions_received_ts", _redactions_received_ts_txn
484478 )
485479
486480 if not count:
487 yield self.db_pool.updates._end_background_update("redactions_received_ts")
481 await self.db_pool.updates._end_background_update("redactions_received_ts")
488482
489483 return count
490484
491 @defer.inlineCallbacks
492 def _event_fix_redactions_bytes(self, progress, batch_size):
485 async def _event_fix_redactions_bytes(self, progress, batch_size):
493486 """Undoes hex encoded censored redacted event JSON.
494487 """
495488
510503
511504 txn.execute("DROP INDEX redactions_censored_redacts")
512505
513 yield self.db_pool.runInteraction(
506 await self.db_pool.runInteraction(
514507 "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
515508 )
516509
517 yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
510 await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
518511
519512 return 1
520513
521 @defer.inlineCallbacks
522 def _event_store_labels(self, progress, batch_size):
514 async def _event_store_labels(self, progress, batch_size):
523515 """Background update handler which will store labels for existing events."""
524516 last_event_id = progress.get("last_event_id", "")
525517
574566
575567 return nbrows
576568
577 num_rows = yield self.db_pool.runInteraction(
569 num_rows = await self.db_pool.runInteraction(
578570 desc="event_store_labels", func=_event_store_labels_txn
579571 )
580572
581573 if not num_rows:
582 yield self.db_pool.updates._end_background_update("event_store_labels")
574 await self.db_pool.updates._end_background_update("event_store_labels")
583575
584576 return num_rows
1818 import logging
1919 import threading
2020 from collections import namedtuple
21 from typing import List, Optional, Tuple
21 from typing import Dict, Iterable, List, Optional, Tuple, overload
2222
2323 from constantly import NamedConstant, Names
24 from typing_extensions import Literal
2425
2526 from twisted.internet import defer
2627
3132 EventFormatVersions,
3233 RoomVersions,
3334 )
34 from synapse.events import make_event_from_dict
35 from synapse.events import EventBase, make_event_from_dict
3536 from synapse.events.utils import prune_event
3637 from synapse.logging.context import PreserveLoggingContext, current_context
3738 from synapse.metrics.background_process_metrics import run_as_background_process
4142 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
4243 from synapse.storage.database import DatabasePool
4344 from synapse.storage.util.id_generators import StreamIdGenerator
44 from synapse.types import get_domain_from_id
45 from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
45 from synapse.types import Collection, get_domain_from_id
46 from synapse.util.caches.descriptors import Cache, cached
4647 from synapse.util.iterutils import batch_iter
4748 from synapse.util.metrics import Measure
4849
111112
112113 def process_replication_rows(self, stream_name, instance_name, token, rows):
113114 if stream_name == EventsStream.NAME:
114 self._stream_id_gen.advance(token)
115 self._stream_id_gen.advance(instance_name, token)
115116 elif stream_name == BackfillStream.NAME:
116 self._backfill_id_gen.advance(-token)
117 self._backfill_id_gen.advance(instance_name, -token)
117118
118119 super().process_replication_rows(stream_name, instance_name, token, rows)
119120
120 def get_received_ts(self, event_id):
121 async def get_received_ts(self, event_id: str) -> Optional[int]:
121122 """Get received_ts (when it was persisted) for the event.
122123
123124 Raises an exception for unknown events.
124125
125126 Args:
126 event_id (str)
127
128 Returns:
129 Deferred[int|None]: Timestamp in milliseconds, or None for events
130 that were persisted before received_ts was implemented.
131 """
132 return self.db_pool.simple_select_one_onecol(
127 event_id: The event ID to query.
128
129 Returns:
130 Timestamp in milliseconds, or None for events that were persisted
131 before received_ts was implemented.
132 """
133 return await self.db_pool.simple_select_one_onecol(
133134 table="events",
134135 keyvalues={"event_id": event_id},
135136 retcol="received_ts",
136137 desc="get_received_ts",
137138 )
138139
139 def get_received_ts_by_stream_pos(self, stream_ordering):
140 """Given a stream ordering get an approximate timestamp of when it
141 happened.
142
143 This is done by simply taking the received ts of the first event that
144 has a stream ordering greater than or equal to the given stream pos.
145 If none exists returns the current time, on the assumption that it must
146 have happened recently.
147
148 Args:
149 stream_ordering (int)
150
151 Returns:
152 Deferred[int]
153 """
154
155 def _get_approximate_received_ts_txn(txn):
156 sql = """
157 SELECT received_ts FROM events
158 WHERE stream_ordering >= ?
159 LIMIT 1
160 """
161
162 txn.execute(sql, (stream_ordering,))
163 row = txn.fetchone()
164 if row and row[0]:
165 ts = row[0]
166 else:
167 ts = self.clock.time_msec()
168
169 return ts
170
171 return self.db_pool.runInteraction(
172 "get_approximate_received_ts", _get_approximate_received_ts_txn
173 )
174
175 @defer.inlineCallbacks
176 def get_event(
140 # Inform mypy that if allow_none is False (the default) then get_event
141 # always returns an EventBase.
142 @overload
143 async def get_event(
144 self,
145 event_id: str,
146 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
147 get_prev_content: bool = False,
148 allow_rejected: bool = False,
149 allow_none: Literal[False] = False,
150 check_room_id: Optional[str] = None,
151 ) -> EventBase:
152 ...
153
154 @overload
155 async def get_event(
156 self,
157 event_id: str,
158 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
159 get_prev_content: bool = False,
160 allow_rejected: bool = False,
161 allow_none: Literal[True] = False,
162 check_room_id: Optional[str] = None,
163 ) -> Optional[EventBase]:
164 ...
165
166 async def get_event(
177167 self,
178168 event_id: str,
179169 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
181171 allow_rejected: bool = False,
182172 allow_none: bool = False,
183173 check_room_id: Optional[str] = None,
184 ):
174 ) -> Optional[EventBase]:
185175 """Get an event from the database by event_id.
186176
187177 Args:
206196 If there is a mismatch, behave as per allow_none.
207197
208198 Returns:
209 Deferred[EventBase|None]
199 The event, or None if the event was not found.
210200 """
211201 if not isinstance(event_id, str):
212202 raise TypeError("Invalid event event_id %r" % (event_id,))
213203
214 events = yield self.get_events_as_list(
204 events = await self.get_events_as_list(
215205 [event_id],
216206 redact_behaviour=redact_behaviour,
217207 get_prev_content=get_prev_content,
229219
230220 return event
231221
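The `@overload` pair above exists purely for the type checker: when `allow_none` is statically `False`, callers get a non-`Optional` result. The same pattern in a self-contained sketch:

```python
from typing import Dict, Optional, overload

from typing_extensions import Literal

DATA: Dict[str, str] = {"a": "1"}

@overload
def lookup(key: str, allow_none: Literal[False] = False) -> str: ...

@overload
def lookup(key: str, allow_none: Literal[True]) -> Optional[str]: ...

def lookup(key: str, allow_none: bool = False) -> Optional[str]:
    # mypy narrows the return type from the static value of allow_none.
    value = DATA.get(key)
    if value is None and not allow_none:
        raise KeyError(key)
    return value
```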
232 @defer.inlineCallbacks
233 def get_events(
222 async def get_events(
234223 self,
235 event_ids: List[str],
224 event_ids: Iterable[str],
236225 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
237226 get_prev_content: bool = False,
238227 allow_rejected: bool = False,
239 ):
228 ) -> Dict[str, EventBase]:
240229 """Get events from the database
241230
242231 Args:
255244 omits rejected events from the response.
256245
257246 Returns:
258 Deferred : Dict from event_id to event.
259 """
260 events = yield self.get_events_as_list(
247 A mapping from event_id to event.
248 """
249 events = await self.get_events_as_list(
261250 event_ids,
262251 redact_behaviour=redact_behaviour,
263252 get_prev_content=get_prev_content,
266255
267256 return {e.event_id: e for e in events}
268257
269 @defer.inlineCallbacks
270 def get_events_as_list(
258 async def get_events_as_list(
271259 self,
272 event_ids: List[str],
260 event_ids: Collection[str],
273261 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
274262 get_prev_content: bool = False,
275263 allow_rejected: bool = False,
276 ):
264 ) -> List[EventBase]:
277265 """Get events from the database and return in a list in the same order
278266 as given by `event_ids` arg.
279267
294282 omits rejected events from the response.
295283
296284 Returns:
297 Deferred[list[EventBase]]: List of events fetched from the database. The
298 events are in the same order as `event_ids` arg.
285 List of events fetched from the database. The events are in the same
286 order as `event_ids` arg.
299287
300288 Note that the returned list may be smaller than the list of event
301289 IDs if not all events could be fetched.
305293 return []
306294
307295 # there may be duplicates so we cast the list to a set
308 event_entry_map = yield self._get_events_from_cache_or_db(
296 event_entry_map = await self._get_events_from_cache_or_db(
309297 set(event_ids), allow_rejected=allow_rejected
310298 )
311299
340328 continue
341329
342330 redacted_event_id = entry.event.redacts
343 event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
331 event_map = await self._get_events_from_cache_or_db([redacted_event_id])
344332 original_event_entry = event_map.get(redacted_event_id)
345333 if not original_event_entry:
346334 # we don't have the redacted event (or it was rejected).
406394
407395 if get_prev_content:
408396 if "replaces_state" in event.unsigned:
409 prev = yield self.get_event(
397 prev = await self.get_event(
410398 event.unsigned["replaces_state"],
411399 get_prev_content=False,
412400 allow_none=True,
418406
419407 return events
420408
421 @defer.inlineCallbacks
422 def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
409 async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
423410 """Fetch a bunch of events from the cache or the database.
424411
425412 If events are pulled from the database, they will be cached for future lookups.
434421 rejected events are omitted from the response.
435422
436423 Returns:
437 Deferred[Dict[str, _EventCacheEntry]]:
424 Dict[str, _EventCacheEntry]:
438425 map from event id to result
439426 """
440427 event_entry_map = self._get_events_from_cache(
452439 # the events have been redacted, and if so pulling the redaction event out
453440 # of the database to check it.
454441 #
455 missing_events = yield self._get_events_from_db(
442 missing_events = await self._get_events_from_db(
456443 missing_events_ids, allow_rejected=allow_rejected
457444 )
458445
560547 with PreserveLoggingContext():
561548 self.hs.get_reactor().callFromThread(fire, event_list, e)
562549
563 @defer.inlineCallbacks
564 def _get_events_from_db(self, event_ids, allow_rejected=False):
550 async def _get_events_from_db(self, event_ids, allow_rejected=False):
565551 """Fetch a bunch of events from the database.
566552
567553 Returned events will be added to the cache for future lookups.
575561 rejected events are omitted from the response.
576562
577563 Returns:
578 Deferred[Dict[str, _EventCacheEntry]]:
564 Dict[str, _EventCacheEntry]:
579565 map from event id to result. May return extra events which
580566 weren't asked for.
581567 """
583569 events_to_fetch = event_ids
584570
585571 while events_to_fetch:
586 row_map = yield self._enqueue_events(events_to_fetch)
572 row_map = await self._enqueue_events(events_to_fetch)
587573
588574 # we need to recursively fetch any redactions of those events
589575 redaction_ids = set()
609595 if not allow_rejected and rejected_reason:
610596 continue
611597
612 d = db_to_json(row["json"])
613 internal_metadata = db_to_json(row["internal_metadata"])
598 # If the event or metadata cannot be parsed, log the error and act
599 # as if the event is unknown.
600 try:
601 d = db_to_json(row["json"])
602 except ValueError:
603 logger.error("Unable to parse json from event: %s", event_id)
604 continue
605 try:
606 internal_metadata = db_to_json(row["internal_metadata"])
607 except ValueError:
608 logger.error(
609 "Unable to parse internal_metadata from event: %s", event_id
610 )
611 continue
614612
615613 format_version = row["format_version"]
616614 if format_version is None:
621619 room_version_id = row["room_version_id"]
622620
623621 if not room_version_id:
624 # this should only happen for out-of-band membership events
625 if not internal_metadata.get("out_of_band_membership"):
626 logger.warning(
627 "Room %s for event %s is unknown", d["room_id"], event_id
622 # this should only happen for out-of-band membership events which
623 # arrived before #6983 landed. For all other events, we should have
624 # an entry in the 'rooms' table.
625 #
626 # However, the 'out_of_band_membership' flag is unreliable for older
627 # invites, so just accept it for all membership events.
628 #
629 if d["type"] != EventTypes.Member:
630 raise Exception(
631 "Room %s for event %s is unknown" % (d["room_id"], event_id)
628632 )
629 continue
630
631 # take a wild stab at the room version based on the event format
633
634 # so, assuming this is an out-of-band-invite that arrived before #6983
635 # landed, we know that the room version must be v5 or earlier (because
636 # v6 hadn't been invented at that point, so invites from such rooms
637 # would have been rejected.)
638 #
639 # The main reason we need to know the room version here (other than
640 # choosing the right python Event class) is in case the event later has
641 # to be redacted - and all the room versions up to v5 used the same
642 # redaction algorithm.
643 #
644 # So, the following approximations should be adequate.
645
632646 if format_version == EventFormatVersions.V1:
647 # if it's event format v1 then it must be room v1 or v2
633648 room_version = RoomVersions.V1
634649 elif format_version == EventFormatVersions.V2:
650 # if it's event format v2 then it must be room v3
635651 room_version = RoomVersions.V3
636652 else:
653 # if it's event format v3 then it must be room v4 or v5
637654 room_version = RoomVersions.V5
638655 else:
639656 room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
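The approximation spelled out in those comments, summarised as a table; this is a restatement of the branch logic above, and the import path is assumed to match the one used by this file:

```python
from synapse.api.room_versions import EventFormatVersions, RoomVersions

# Event format -> assumed room version for pre-#6983 out-of-band invites.
FORMAT_TO_ASSUMED_ROOM_VERSION = {
    EventFormatVersions.V1: RoomVersions.V1,  # format v1 => rooms v1/v2
    EventFormatVersions.V2: RoomVersions.V3,  # format v2 => room v3
    EventFormatVersions.V3: RoomVersions.V5,  # format v3 => rooms v4/v5
}
```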
685702
686703 return result_map
687704
688 @defer.inlineCallbacks
689 def _enqueue_events(self, events):
705 async def _enqueue_events(self, events):
690706 """Fetches events from the database using the _event_fetch_list. This
691707 allows batch and bulk fetching of events - it allows us to fetch events
692708 without having to create a new transaction for each request for events.
695711 events (Iterable[str]): events to be fetched.
696712
697713 Returns:
698 Deferred[Dict[str, Dict]]: map from event id to row data from the database.
714 Dict[str, Dict]: map from event id to row data from the database.
699715 May contain events that weren't requested.
700716 """
701717
718734
719735 logger.debug("Loading %d events: %s", len(events), events)
720736 with PreserveLoggingContext():
721 row_map = yield events_d
737 row_map = await events_d
722738 logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
723739
724740 return row_map
806822
807823 return event_dict
808824
809 def _maybe_redact_event_row(self, original_ev, redactions, event_map):
825 def _maybe_redact_event_row(
826 self,
827 original_ev: EventBase,
828 redactions: Iterable[str],
829 event_map: Dict[str, EventBase],
830 ) -> Optional[EventBase]:
810831 """Given an event object and a list of possible redacting event ids,
811832 determine whether to honour any of those redactions and if so return a redacted
812833 event.
813834
814835 Args:
815 original_ev (EventBase):
816 redactions (iterable[str]): list of event ids of potential redaction events
817 event_map (dict[str, EventBase]): other events which have been fetched, in
818 which we can look up the redaaction events. Map from event id to event.
819
820 Returns:
821 Deferred[EventBase|None]: if the event should be redacted, a pruned
822 event object. Otherwise, None.
836 original_ev: The original event.
837 redactions: list of event ids of potential redaction events
838 event_map: other events which have been fetched, in which we can
839 look up the redaction events. Map from event id to event.
840
841 Returns:
842 If the event should be redacted, a pruned event object. Otherwise, None.
823843 """
824844 if original_ev.type == "m.room.create":
825845 # we choose to ignore redactions of m.room.create events.
877897 # no valid redaction found for this event
878898 return None
879899
880 @defer.inlineCallbacks
881 def have_events_in_timeline(self, event_ids):
900 async def have_events_in_timeline(self, event_ids):
882901 """Given a list of event ids, check if we have already processed and
883902 stored them as non outliers.
884903 """
885 rows = yield self.db_pool.simple_select_many_batch(
904 rows = await self.db_pool.simple_select_many_batch(
886905 table="events",
887906 retcols=("event_id",),
888907 column="event_id",
893912
894913 return {r["event_id"] for r in rows}
895914
896 @defer.inlineCallbacks
897 def have_seen_events(self, event_ids):
915 async def have_seen_events(self, event_ids):
898916 """Given a list of event ids, check if we have already processed them.
899917
900918 Args:
901919 event_ids (iterable[str]):
902920
903921 Returns:
904 Deferred[set[str]]: The events we have already seen.
922 set[str]: The events we have already seen.
905923 """
906924 results = set()
907925
917935 # break the input up into chunks of 100
918936 input_iterator = iter(event_ids)
919937 for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
920 yield self.db_pool.runInteraction(
938 await self.db_pool.runInteraction(
921939 "have_seen_events", have_seen_events_txn, chunk
922940 )
923941 return results
924
925 def _get_total_state_event_counts_txn(self, txn, room_id):
926 """
927 See get_total_state_event_counts.
928 """
929 # We join against the events table as that has an index on room_id
930 sql = """
931 SELECT COUNT(*) FROM state_events
932 INNER JOIN events USING (room_id, event_id)
933 WHERE room_id=?
934 """
935 txn.execute(sql, (room_id,))
936 row = txn.fetchone()
937 return row[0] if row else 0
938
939 def get_total_state_event_counts(self, room_id):
940 """
941 Gets the total number of state events in a room.
942
943 Args:
944 room_id (str)
945
946 Returns:
947 Deferred[int]
948 """
949 return self.db_pool.runInteraction(
950 "get_total_state_event_counts",
951 self._get_total_state_event_counts_txn,
952 room_id,
953 )
954942
955943 def _get_current_state_event_counts_txn(self, txn, room_id):
956944 """
961949 row = txn.fetchone()
962950 return row[0] if row else 0
963951
964 def get_current_state_event_counts(self, room_id):
952 async def get_current_state_event_counts(self, room_id: str) -> int:
965953 """
966954 Gets the current number of state events in a room.
967955
968956 Args:
969 room_id (str)
970
971 Returns:
972 Deferred[int]
973 """
974 return self.db_pool.runInteraction(
957 room_id: The room ID to query.
958
959 Returns:
960 The current number of state events.
961 """
962 return await self.db_pool.runInteraction(
975963 "get_current_state_event_counts",
976964 self._get_current_state_event_counts_txn,
977965 room_id,
978966 )
979967
980 @defer.inlineCallbacks
981 def get_room_complexity(self, room_id):
968 async def get_room_complexity(self, room_id):
982969 """
983970 Get a rough approximation of the complexity of the room. This is used by
984971 remote servers to decide whether they wish to join the room or not.
989976 room_id (str)
990977
991978 Returns:
992 Deferred[dict[str:int]] of complexity version to complexity.
993 """
994 state_events = yield self.get_current_state_event_counts(room_id)
979 dict[str, int] of complexity version to complexity.
980 """
981 state_events = await self.get_current_state_event_counts(room_id)
995982
996983 # Call this one "v1", so we can introduce new ones as we want to develop
997984 # it.
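For reference, the "v1" measure is derived from the state-event count; a hedged sketch follows (the divisor of 500 is recalled from Synapse's v1 complexity formula and may not match this exact version):

```python
def room_complexity_v1(state_event_count: int) -> dict:
    # v1 complexity: state events scaled by a constant (500 as recalled).
    return {"v1": round(state_event_count / 500, 2)}
```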
1007994 """The current maximum token that events have reached"""
1008995 return self._stream_id_gen.get_current_token()
1009996
1010 def get_all_new_forward_event_rows(self, last_id, current_id, limit):
997 async def get_all_new_forward_event_rows(
998 self, last_id: int, current_id: int, limit: int
999 ) -> List[Tuple]:
10111000 """Returns new events, for the Events replication stream
10121001
10131002 Args:
10151004 current_id: the maximum stream_id to return up to
10161005 limit: the maximum number of rows to return
10171006
1018 Returns: Deferred[List[Tuple]]
1007 Returns:
10191008 a list of events stream rows. Each tuple consists of a stream id as
10201009 the first element, followed by fields suitable for casting into an
10211010 EventsStreamRow.
10361025 txn.execute(sql, (last_id, current_id, limit))
10371026 return txn.fetchall()
10381027
1039 return self.db_pool.runInteraction(
1028 return await self.db_pool.runInteraction(
10401029 "get_all_new_forward_event_rows", get_all_new_forward_event_rows
10411030 )
10421031
1043 def get_ex_outlier_stream_rows(self, last_id, current_id):
1032 async def get_ex_outlier_stream_rows(
1033 self, last_id: int, current_id: int
1034 ) -> List[Tuple]:
10441035 """Returns de-outliered events, for the Events replication stream
10451036
10461037 Args:
10471038 last_id: the last stream_id from the previous batch.
10481039 current_id: the maximum stream_id to return up to
10491040
1050 Returns: Deferred[List[Tuple]]
1041 Returns:
10511042 a list of events stream rows. Each tuple consists of a stream id as
10521043 the first element, followed by fields suitable for casting into an
10531044 EventsStreamRow.
10701061 txn.execute(sql, (last_id, current_id))
10711062 return txn.fetchall()
10721063
1073 return self.db_pool.runInteraction(
1064 return await self.db_pool.runInteraction(
10741065 "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
10751066 )
10761067
12211212
12221213 return rows, to_token, True
12231214
1224 @cached(num_args=5, max_entries=10)
1225 def get_all_new_events(
1226 self,
1227 last_backfill_id,
1228 last_forward_id,
1229 current_backfill_id,
1230 current_forward_id,
1231 limit,
1232 ):
1233 """Get all the new events that have arrived at the server either as
1234 new events or as backfilled events"""
1235 have_backfill_events = last_backfill_id != current_backfill_id
1236 have_forward_events = last_forward_id != current_forward_id
1237
1238 if not have_backfill_events and not have_forward_events:
1239 return defer.succeed(AllNewEventsResult([], [], [], [], []))
1240
1241 def get_all_new_events_txn(txn):
1242 sql = (
1243 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
1244 " state_key, redacts"
1245 " FROM events AS e"
1246 " LEFT JOIN redactions USING (event_id)"
1247 " LEFT JOIN state_events USING (event_id)"
1248 " WHERE ? < stream_ordering AND stream_ordering <= ?"
1249 " ORDER BY stream_ordering ASC"
1250 " LIMIT ?"
1251 )
1252 if have_forward_events:
1253 txn.execute(sql, (last_forward_id, current_forward_id, limit))
1254 new_forward_events = txn.fetchall()
1255
1256 if len(new_forward_events) == limit:
1257 upper_bound = new_forward_events[-1][0]
1258 else:
1259 upper_bound = current_forward_id
1260
1261 sql = (
1262 "SELECT event_stream_ordering, event_id, state_group"
1263 " FROM ex_outlier_stream"
1264 " WHERE ? > event_stream_ordering"
1265 " AND event_stream_ordering >= ?"
1266 " ORDER BY event_stream_ordering DESC"
1267 )
1268 txn.execute(sql, (last_forward_id, upper_bound))
1269 forward_ex_outliers = txn.fetchall()
1270 else:
1271 new_forward_events = []
1272 forward_ex_outliers = []
1273
1274 sql = (
1275 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
1276 " state_key, redacts"
1277 " FROM events AS e"
1278 " LEFT JOIN redactions USING (event_id)"
1279 " LEFT JOIN state_events USING (event_id)"
1280 " WHERE ? > stream_ordering AND stream_ordering >= ?"
1281 " ORDER BY stream_ordering DESC"
1282 " LIMIT ?"
1283 )
1284 if have_backfill_events:
1285 txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
1286 new_backfill_events = txn.fetchall()
1287
1288 if len(new_backfill_events) == limit:
1289 upper_bound = new_backfill_events[-1][0]
1290 else:
1291 upper_bound = current_backfill_id
1292
1293 sql = (
1294 "SELECT -event_stream_ordering, event_id, state_group"
1295 " FROM ex_outlier_stream"
1296 " WHERE ? > event_stream_ordering"
1297 " AND event_stream_ordering >= ?"
1298 " ORDER BY event_stream_ordering DESC"
1299 )
1300 txn.execute(sql, (-last_backfill_id, -upper_bound))
1301 backward_ex_outliers = txn.fetchall()
1302 else:
1303 new_backfill_events = []
1304 backward_ex_outliers = []
1305
1306 return AllNewEventsResult(
1307 new_forward_events,
1308 new_backfill_events,
1309 forward_ex_outliers,
1310 backward_ex_outliers,
1311 )
1312
1313 return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
1314
13151215 async def is_event_after(self, event_id1, event_id2):
13161216 """Returns True if event_id1 is after event_id2 in the stream
13171217 """
13191219 to_2, so_2 = await self.get_event_ordering(event_id2)
13201220 return (to_1, so_1) > (to_2, so_2)
13211221
1322 @cachedInlineCallbacks(max_entries=5000)
1323 def get_event_ordering(self, event_id):
1324 res = yield self.db_pool.simple_select_one(
1222 @cached(max_entries=5000)
1223 async def get_event_ordering(self, event_id):
1224 res = await self.db_pool.simple_select_one(
13251225 table="events",
13261226 retcols=["topological_ordering", "stream_ordering"],
13271227 keyvalues={"event_id": event_id},
13331233
13341234 return (int(res["topological_ordering"]), int(res["stream_ordering"]))
13351235
1336 def get_next_event_to_expire(self):
1236 async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
13371237 """Retrieve the entry with the lowest expiry timestamp in the event_expiry
13381238 table, or None if there's no more event to expire.
13391239
1340 Returns: Deferred[Optional[Tuple[str, int]]]
1240 Returns:
13411241 A tuple containing the event ID as its first element and an expiry timestamp
13421242 as its second one, if there's at least one row in the event_expiry table.
13431243 None otherwise.
13531253
13541254 return txn.fetchone()
13551255
1356 return self.db_pool.runInteraction(
1256 return await self.db_pool.runInteraction(
13571257 desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
13581258 )
1359
1360
1361 AllNewEventsResult = namedtuple(
1362 "AllNewEventsResult",
1363 [
1364 "new_forward_events",
1365 "new_backfill_events",
1366 "forward_ex_outliers",
1367 "backward_ex_outliers",
1368 ],
1369 )
1616
1717 from synapse.api.errors import Codes, SynapseError
1818 from synapse.storage._base import SQLBaseStore, db_to_json
19 from synapse.types import JsonDict
1920 from synapse.util.caches.descriptors import cached
2021
2122
3940
4041 return db_to_json(def_json)
4142
42 def add_user_filter(self, user_localpart, user_filter):
43 async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
4344 def_json = encode_canonical_json(user_filter)
4445
4546 # Need an atomic transaction to SELECT the maximal ID so far then
7071
7172 return filter_id
7273
73 return self.db_pool.runInteraction("add_user_filter", _do_txn)
74 return await self.db_pool.runInteraction("add_user_filter", _do_txn)
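A usage sketch for the converted method; the localpart and filter body are illustrative, with the filter shape following the Matrix filtering API:

```python
# Store a filter definition and keep the returned ID for use with /sync.
filter_id = await store.add_user_filter(
    user_localpart="alice",
    user_filter={"room": {"timeline": {"limit": 10}}},
)
```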
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 from typing import List, Optional, Tuple
16 from typing import Any, Dict, List, Optional, Tuple, Union
1717
1818 from synapse.api.errors import SynapseError
1919 from synapse.storage._base import SQLBaseStore, db_to_json
2727
2828
2929 class GroupServerWorkerStore(SQLBaseStore):
30 def get_group(self, group_id):
31 return self.db_pool.simple_select_one(
30 async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
31 return await self.db_pool.simple_select_one(
3232 table="groups",
3333 keyvalues={"group_id": group_id},
3434 retcols=(
4343 desc="get_group",
4444 )
4545
46 def get_users_in_group(self, group_id, include_private=False):
46 async def get_users_in_group(
47 self, group_id: str, include_private: bool = False
48 ) -> List[Dict[str, Any]]:
4749 # TODO: Pagination
4850
4951 keyvalues = {"group_id": group_id}
5052 if not include_private:
5153 keyvalues["is_public"] = True
5254
53 return self.db_pool.simple_select_list(
55 return await self.db_pool.simple_select_list(
5456 table="group_users",
5557 keyvalues=keyvalues,
5658 retcols=("user_id", "is_public", "is_admin"),
5759 desc="get_users_in_group",
5860 )
5961
60 def get_invited_users_in_group(self, group_id):
62 async def get_invited_users_in_group(self, group_id: str) -> List[str]:
6163 # TODO: Pagination
6264
63 return self.db_pool.simple_select_onecol(
65 return await self.db_pool.simple_select_onecol(
6466 table="group_invites",
6567 keyvalues={"group_id": group_id},
6668 retcol="user_id",
6769 desc="get_invited_users_in_group",
6870 )
6971
70 def get_rooms_in_group(self, group_id: str, include_private: bool = False):
72 async def get_rooms_in_group(
73 self, group_id: str, include_private: bool = False
74 ) -> List[Dict[str, Union[str, bool]]]:
7175 """Retrieve the rooms that belong to a given group. Does not return rooms that
7276 lack members.
7377
7680 include_private: Whether to return private rooms in results
7781
7882 Returns:
79 Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
80 form of:
83 A list of dictionaries, each in the form of:
8184
8285 {
8386 "room_id": "!a_room_id:example.com", # The ID of the room
114117 for room_id, is_public in txn
115118 ]
116119
117 return self.db_pool.runInteraction(
120 return await self.db_pool.runInteraction(
118121 "get_rooms_in_group", _get_rooms_in_group_txn
119122 )
120123
121 def get_rooms_for_summary_by_category(
124 async def get_rooms_for_summary_by_category(
122125 self, group_id: str, include_private: bool = False,
123 ):
126 ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
124127 """Get the rooms and categories that should be included in a summary request
125128
126129 Args:
128131 include_private: Whether to return private rooms in results
129132
130133 Returns:
131 Deferred[Tuple[List, Dict]]: A tuple containing:
134 A tuple containing:
132135
133136 * A list of dictionaries with the keys:
134137 * "room_id": str, the room ID
204207
205208 return rooms, categories
206209
207 return self.db_pool.runInteraction(
210 return await self.db_pool.runInteraction(
208211 "get_rooms_for_summary", _get_rooms_for_summary_txn
209212 )
210213
264267
265268 return role
266269
267 def get_local_groups_for_room(self, room_id):
270 async def get_local_groups_for_room(self, room_id: str) -> List[str]:
268271 """Get all of the local group that contain a given room
269272 Args:
270 room_id (str): The ID of a room
273 room_id: The ID of a room
271274 Returns:
272 Deferred[list[str]]: A twisted.Deferred containing a list of group ids
273 containing this room
274 """
275 return self.db_pool.simple_select_onecol(
275 A list of group ids containing this room
276 """
277 return await self.db_pool.simple_select_onecol(
276278 table="group_rooms",
277279 keyvalues={"room_id": room_id},
278280 retcol="group_id",
279281 desc="get_local_groups_for_room",
280282 )
281283
282 def get_users_for_summary_by_role(self, group_id, include_private=False):
284 async def get_users_for_summary_by_role(self, group_id, include_private=False):
283285 """Get the users and roles that should be included in a summary request
284286
285 Returns ([users], [roles])
287 Returns:
288 ([users], [roles])
286289 """
287290
288291 def _get_users_for_summary_txn(txn):
336339
337340 return users, roles
338341
339 return self.db_pool.runInteraction(
342 return await self.db_pool.runInteraction(
340343 "get_users_for_summary_by_role", _get_users_for_summary_txn
341344 )
342345
343 def is_user_in_group(self, user_id, group_id):
344 return self.db_pool.simple_select_one_onecol(
346 async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
347 result = await self.db_pool.simple_select_one_onecol(
345348 table="group_users",
346349 keyvalues={"group_id": group_id, "user_id": user_id},
347350 retcol="user_id",
348351 allow_none=True,
349352 desc="is_user_in_group",
350 ).addCallback(lambda r: bool(r))
351
352 def is_user_admin_in_group(self, group_id, user_id):
353 return self.db_pool.simple_select_one_onecol(
353 )
354 return bool(result)
355
356 async def is_user_admin_in_group(
357 self, group_id: str, user_id: str
358 ) -> Optional[bool]:
359 return await self.db_pool.simple_select_one_onecol(
354360 table="group_users",
355361 keyvalues={"group_id": group_id, "user_id": user_id},
356362 retcol="is_admin",
358364 desc="is_user_admin_in_group",
359365 )
360366
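The `is_user_in_group` change above shows the other recurring conversion in this diff: post-processing that used to hang off `addCallback` now happens inline after `await`. Generically:

```python
from twisted.internet import defer

# Before: coerce the query result via a callback chained on the Deferred.
def nonempty(d: defer.Deferred) -> defer.Deferred:
    return d.addCallback(lambda r: bool(r))

# After: await the Deferred and coerce inline; control flow reads
# top-to-bottom and the return type can be annotated precisely.
async def nonempty_async(d: defer.Deferred) -> bool:
    return bool(await d)
```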
361 def is_user_invited_to_local_group(self, group_id, user_id):
367 async def is_user_invited_to_local_group(
368 self, group_id: str, user_id: str
369 ) -> Optional[bool]:
362370 """Has the group server invited a user?
363371 """
364 return self.db_pool.simple_select_one_onecol(
372 return await self.db_pool.simple_select_one_onecol(
365373 table="group_invites",
366374 keyvalues={"group_id": group_id, "user_id": user_id},
367375 retcol="user_id",
369377 allow_none=True,
370378 )
371379
372 def get_users_membership_info_in_group(self, group_id, user_id):
380 async def get_users_membership_info_in_group(self, group_id, user_id):
373381 """Get a dict describing the membership of a user in a group.
374382
375383 Example if joined:
380388 "is_privileged": False,
381389 }
382390
383 Returns an empty dict if the user is not join/invite/etc
391 Returns:
392 An empty dict if the user is not join/invite/etc
384393 """
385394
386395 def _get_users_membership_in_group_txn(txn):
412421
413422 return {}
414423
415 return self.db_pool.runInteraction(
424 return await self.db_pool.runInteraction(
416425 "get_users_membership_info_in_group", _get_users_membership_in_group_txn
417426 )
418427
419 def get_publicised_groups_for_user(self, user_id):
428 async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
420429 """Get all groups a user is publicising
421430 """
422 return self.db_pool.simple_select_onecol(
431 return await self.db_pool.simple_select_onecol(
423432 table="local_group_membership",
424433 keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
425434 retcol="group_id",
426435 desc="get_publicised_groups_for_user",
427436 )
428437
429 def get_attestations_need_renewals(self, valid_until_ms):
438 async def get_attestations_need_renewals(self, valid_until_ms):
430439 """Get all attestations that need to be renewed until givent time
431440 """
432441
438447 txn.execute(sql, (valid_until_ms,))
439448 return self.db_pool.cursor_to_dict(txn)
440449
441 return self.db_pool.runInteraction(
450 return await self.db_pool.runInteraction(
442451 "get_attestations_need_renewals", _get_attestations_need_renewals_txn
443452 )
444453
460469
461470 return None
462471
463 def get_joined_groups(self, user_id):
464 return self.db_pool.simple_select_onecol(
472 async def get_joined_groups(self, user_id: str) -> List[str]:
473 return await self.db_pool.simple_select_onecol(
465474 table="local_group_membership",
466475 keyvalues={"user_id": user_id, "membership": "join"},
467476 retcol="group_id",
468477 desc="get_joined_groups",
469478 )
470479
471 def get_all_groups_for_user(self, user_id, now_token):
480 async def get_all_groups_for_user(self, user_id, now_token):
472481 def _get_all_groups_for_user_txn(txn):
473482 sql = """
474483 SELECT group_id, type, membership, u.content
488497 for row in txn
489498 ]
490499
491 return self.db_pool.runInteraction(
500 return await self.db_pool.runInteraction(
492501 "get_all_groups_for_user", _get_all_groups_for_user_txn
493502 )
494503
579588
580589
581590 class GroupServerStore(GroupServerWorkerStore):
582 def set_group_join_policy(self, group_id, join_policy):
591 async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
583592 """Set the join policy of a group.
584593
585594 join_policy can be one of:
586595 * "invite"
587596 * "open"
588597 """
589 return self.db_pool.simple_update_one(
598 await self.db_pool.simple_update_one(
590599 table="groups",
591600 keyvalues={"group_id": group_id},
592601 updatevalues={"join_policy": join_policy},
593602 desc="set_group_join_policy",
594603 )
595604
596 def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
597 return self.db_pool.runInteraction(
605 async def add_room_to_summary(
606 self,
607 group_id: str,
608 room_id: str,
609 category_id: str,
610 order: int,
611 is_public: Optional[bool],
612 ) -> None:
613 """Add (or update) room's entry in summary.
614
615 Args:
616 group_id
617 room_id
618 category_id: If not None then adds the category to the end of
619 the summary if it's not already there.
620 order: If not None inserts the room at that position, e.g. an order
621 of 1 will put the room first. Otherwise, the room gets added to
622 the end.
623 is_public
624 """
625 await self.db_pool.runInteraction(
598626 "add_room_to_summary",
599627 self._add_room_to_summary_txn,
600628 group_id,
605633 )
606634
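A usage sketch for the ordering semantics documented above (the group and room IDs are illustrative):

```python
# Pin a room to the top of the default category of a group summary:
# per the docstring, order=1 puts the room first, and a category_id of
# None targets the default category.
await store.add_room_to_summary(
    group_id="+somegroup:example.com",
    room_id="!someroom:example.com",
    category_id=None,
    order=1,
    is_public=True,
)
```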
607635 def _add_room_to_summary_txn(
608 self, txn, group_id, room_id, category_id, order, is_public
609 ):
636 self,
637 txn,
638 group_id: str,
639 room_id: str,
640 category_id: str,
641 order: int,
642 is_public: Optional[bool],
643 ) -> None:
610644 """Add (or update) room's entry in summary.
611645
612646 Args:
613 group_id (str)
614 room_id (str)
615 category_id (str): If not None then adds the category to the end of
616 the summary if its not already there. [Optional]
617 order (int): If not None inserts the room at that position, e.g.
618 an order of 1 will put the room first. Otherwise, the room gets
619 added to the end.
647 txn
648 group_id
649 room_id
650 category_id: If not None then adds the category to the end of
651 the summary if it's not already there.
652 order: If not None inserts the room at that position, e.g. an order
653 of 1 will put the room first. Otherwise, the room gets added to
654 the end.
655 is_public
620656 """
621657 room_in_group = self.db_pool.simple_select_one_onecol_txn(
622658 txn,
721757 },
722758 )
723759
724 def remove_room_from_summary(self, group_id, room_id, category_id):
760 async def remove_room_from_summary(
761 self, group_id: str, room_id: str, category_id: str
762 ) -> int:
725763 if category_id is None:
726764 category_id = _DEFAULT_CATEGORY_ID
727765
728 return self.db_pool.simple_delete(
766 return await self.db_pool.simple_delete(
729767 table="group_summary_rooms",
730768 keyvalues={
731769 "group_id": group_id,
735773 desc="remove_room_from_summary",
736774 )
737775
738 def upsert_group_category(self, group_id, category_id, profile, is_public):
776 async def upsert_group_category(
777 self,
778 group_id: str,
779 category_id: str,
780 profile: Optional[JsonDict],
781 is_public: Optional[bool],
782 ) -> None:
739783 """Add/update room category for group
740784 """
741785 insertion_values = {}
751795 else:
752796 update_values["is_public"] = is_public
753797
754 return self.db_pool.simple_upsert(
798 await self.db_pool.simple_upsert(
755799 table="group_room_categories",
756800 keyvalues={"group_id": group_id, "category_id": category_id},
757801 values=update_values,
759803 desc="upsert_group_category",
760804 )
761805
762 def remove_group_category(self, group_id, category_id):
763 return self.db_pool.simple_delete(
806 async def remove_group_category(self, group_id: str, category_id: str) -> int:
807 return await self.db_pool.simple_delete(
764808 table="group_room_categories",
765809 keyvalues={"group_id": group_id, "category_id": category_id},
766810 desc="remove_group_category",
767811 )
768812
769 def upsert_group_role(self, group_id, role_id, profile, is_public):
813 async def upsert_group_role(
814 self,
815 group_id: str,
816 role_id: str,
817 profile: Optional[JsonDict],
818 is_public: Optional[bool],
819 ) -> None:
770820 """Add/remove user role
771821 """
772822 insertion_values = {}
782832 else:
783833 update_values["is_public"] = is_public
784834
785 return self.db_pool.simple_upsert(
835 await self.db_pool.simple_upsert(
786836 table="group_roles",
787837 keyvalues={"group_id": group_id, "role_id": role_id},
788838 values=update_values,
790840 desc="upsert_group_role",
791841 )
792842
793 def remove_group_role(self, group_id, role_id):
794 return self.db_pool.simple_delete(
843 async def remove_group_role(self, group_id: str, role_id: str) -> int:
844 return await self.db_pool.simple_delete(
795845 table="group_roles",
796846 keyvalues={"group_id": group_id, "role_id": role_id},
797847 desc="remove_group_role",
798848 )
799849
800 def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
801 return self.db_pool.runInteraction(
850 async def add_user_to_summary(
851 self,
852 group_id: str,
853 user_id: str,
854 role_id: str,
855 order: int,
856 is_public: Optional[bool],
857 ) -> None:
858 """Add (or update) user's entry in summary.
859
860 Args:
861 group_id
862 user_id
863 role_id: If not None then adds the role to the end of the summary if
864 it's not already there.
865 order: If not None inserts the user at that position, e.g. an order
866 of 1 will put the user first. Otherwise, the user gets added to
867 the end.
868 is_public
869 """
870 await self.db_pool.runInteraction(
802871 "add_user_to_summary",
803872 self._add_user_to_summary_txn,
804873 group_id,
809878 )
810879
811880 def _add_user_to_summary_txn(
812 self, txn, group_id, user_id, role_id, order, is_public
881 self,
882 txn,
883 group_id: str,
884 user_id: str,
885 role_id: str,
886 order: int,
887 is_public: Optional[bool],
813888 ):
814889 """Add (or update) user's entry in summary.
815890
816891 Args:
817 group_id (str)
818 user_id (str)
819 role_id (str): If not None then adds the role to the end of
820 the summary if its not already there. [Optional]
821 order (int): If not None inserts the user at that position, e.g.
822 an order of 1 will put the user first. Otherwise, the user gets
823 added to the end.
892 txn
893 group_id
894 user_id
895 role_id: If not None then adds the role to the end of the summary if
896 it's not already there.
897 order: If not None inserts the user at that position, e.g. an order
898 of 1 will put the user first. Otherwise, the user gets added to
899 the end.
900 is_public
824901 """
825902 user_in_group = self.db_pool.simple_select_one_onecol_txn(
826903 txn,
921998 },
922999 )
9231000
924 def remove_user_from_summary(self, group_id, user_id, role_id):
1001 async def remove_user_from_summary(
1002 self, group_id: str, user_id: str, role_id: str
1003 ) -> int:
9251004 if role_id is None:
9261005 role_id = _DEFAULT_ROLE_ID
9271006
928 return self.db_pool.simple_delete(
1007 return await self.db_pool.simple_delete(
9291008 table="group_summary_users",
9301009 keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
9311010 desc="remove_user_from_summary",
9321011 )
9331012
934 def add_group_invite(self, group_id, user_id):
1013 async def add_group_invite(self, group_id: str, user_id: str) -> None:
9351014 """Record that the group server has invited a user
9361015 """
937 return self.db_pool.simple_insert(
1016 await self.db_pool.simple_insert(
9381017 table="group_invites",
9391018 values={"group_id": group_id, "user_id": user_id},
9401019 desc="add_group_invite",
9411020 )
9421021
943 def add_user_to_group(
1022 async def add_user_to_group(
9441023 self,
945 group_id,
946 user_id,
947 is_admin=False,
948 is_public=True,
949 local_attestation=None,
950 remote_attestation=None,
951 ):
1024 group_id: str,
1025 user_id: str,
1026 is_admin: bool = False,
1027 is_public: bool = True,
1028 local_attestation: Optional[dict] = None,
1029 remote_attestation: Optional[dict] = None,
1030 ) -> None:
9521031 """Add a user to the group server.
9531032
9541033 Args:
955 group_id (str)
956 user_id (str)
957 is_admin (bool)
958 is_public (bool)
959 local_attestation (dict): The attestation the GS created to give
960 to the remote server. Optional if the user and group are on the
961 same server
962 remote_attestation (dict): The attestation given to GS by remote
1034 group_id
1035 user_id
1036 is_admin
1037 is_public
1038 local_attestation: The attestation the GS created to give to the remote
9631039 server. Optional if the user and group are on the same server
1040 remote_attestation: The attestation given to GS by remote server.
1041 Optional if the user and group are on the same server
9641042 """
9651043
9661044 def _add_user_to_group_txn(txn):
10031081 },
10041082 )
10051083
1006 return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
1007
1008 def remove_user_from_group(self, group_id, user_id):
1084 await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
1085
1086 async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
10091087 def _remove_user_from_group_txn(txn):
10101088 self.db_pool.simple_delete_txn(
10111089 txn,
10331111 keyvalues={"group_id": group_id, "user_id": user_id},
10341112 )
10351113
1036 return self.db_pool.runInteraction(
1114 await self.db_pool.runInteraction(
10371115 "remove_user_from_group", _remove_user_from_group_txn
10381116 )
10391117
1040 def add_room_to_group(self, group_id, room_id, is_public):
1041 return self.db_pool.simple_insert(
1118 async def add_room_to_group(
1119 self, group_id: str, room_id: str, is_public: bool
1120 ) -> None:
1121 await self.db_pool.simple_insert(
10421122 table="group_rooms",
10431123 values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
10441124 desc="add_room_to_group",
10451125 )
10461126
1047 def update_room_in_group_visibility(self, group_id, room_id, is_public):
1048 return self.db_pool.simple_update(
1127 async def update_room_in_group_visibility(
1128 self, group_id: str, room_id: str, is_public: bool
1129 ) -> int:
1130 return await self.db_pool.simple_update(
10491131 table="group_rooms",
10501132 keyvalues={"group_id": group_id, "room_id": room_id},
10511133 updatevalues={"is_public": is_public},
10521134 desc="update_room_in_group_visibility",
10531135 )
10541136
1055 def remove_room_from_group(self, group_id, room_id):
1137 async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
10561138 def _remove_room_from_group_txn(txn):
10571139 self.db_pool.simple_delete_txn(
10581140 txn,
10661148 keyvalues={"group_id": group_id, "room_id": room_id},
10671149 )
10681150
1069 return self.db_pool.runInteraction(
1151 await self.db_pool.runInteraction(
10701152 "remove_room_from_group", _remove_room_from_group_txn
10711153 )
10721154
1073 def update_group_publicity(self, group_id, user_id, publicise):
1155 async def update_group_publicity(
1156 self, group_id: str, user_id: str, publicise: bool
1157 ) -> None:
10741158 """Update whether the user is publicising their membership of the group
10751159 """
1076 return self.db_pool.simple_update_one(
1160 await self.db_pool.simple_update_one(
10771161 table="local_group_membership",
10781162 keyvalues={"group_id": group_id, "user_id": user_id},
10791163 updatevalues={"is_publicised": publicise},
11801264
11811265 return next_id
11821266
1183 with self._group_updates_id_gen.get_next() as next_id:
1267 with await self._group_updates_id_gen.get_next() as next_id:
11841268 res = await self.db_pool.runInteraction(
11851269 "register_user_group_membership",
11861270 _register_user_group_membership_txn,
12121296 desc="update_group_profile",
12131297 )
12141298
1215 def update_attestation_renewal(self, group_id, user_id, attestation):
1299 async def update_attestation_renewal(
1300 self, group_id: str, user_id: str, attestation: dict
1301 ) -> None:
12161302 """Update an attestation that we have renewed
12171303 """
1218 return self.db_pool.simple_update_one(
1304 await self.db_pool.simple_update_one(
12191305 table="group_attestations_renewals",
12201306 keyvalues={"group_id": group_id, "user_id": user_id},
12211307 updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
12221308 desc="update_attestation_renewal",
12231309 )
12241310
1225 def update_remote_attestion(self, group_id, user_id, attestation):
1311 async def update_remote_attestion(
1312 self, group_id: str, user_id: str, attestation: dict
1313 ) -> None:
12261314 """Update an attestation that a remote has renewed
12271315 """
1228 return self.db_pool.simple_update_one(
1316 await self.db_pool.simple_update_one(
12291317 table="group_attestations_remote",
12301318 keyvalues={"group_id": group_id, "user_id": user_id},
12311319 updatevalues={
12351323 desc="update_remote_attestion",
12361324 )
12371325
1238 def remove_attestation_renewal(self, group_id, user_id):
1326 async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int:
12391327 """Remove an attestation that we thought we should renew, but actually
12401328 shouldn't. Ideally this would never get called as we would never
12411329 incorrectly try and do attestations for local users on local groups.
12421330
12431331 Args:
1244 group_id (str)
1245 user_id (str)
1246 """
1247 return self.db_pool.simple_delete(
1332 group_id
1333 user_id
1334 """
1335 return await self.db_pool.simple_delete(
12481336 table="group_attestations_renewals",
12491337 keyvalues={"group_id": group_id, "user_id": user_id},
12501338 desc="remove_attestation_renewal",
12531341 def get_group_stream_token(self):
12541342 return self._group_updates_id_gen.get_current_token()
12551343
1256 def delete_group(self, group_id):
1344 async def delete_group(self, group_id: str) -> None:
12571345 """Deletes a group fully from the database.
12581346
12591347 Args:
1260 group_id (str)
1261
1262 Returns:
1263 Deferred
1348 group_id: The group ID to delete.
12641349 """
12651350
12661351 def _delete_group_txn(txn):
12841369 txn, table=table, keyvalues={"group_id": group_id}
12851370 )
12861371
1287 return self.db_pool.runInteraction("delete_group", _delete_group_txn)
1372 await self.db_pool.runInteraction("delete_group", _delete_group_txn)
1515
1616 import itertools
1717 import logging
18 from typing import Dict, Iterable, List, Optional, Tuple
1819
1920 from signedjson.key import decode_verify_key_bytes
2021
4041 @cachedList(
4142 cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
4243 )
43 def get_server_verify_keys(self, server_name_and_key_ids):
44 """
45 Args:
46 server_name_and_key_ids (iterable[Tuple[str, str]]):
44 async def get_server_verify_keys(
45 self, server_name_and_key_ids: Iterable[Tuple[str, str]]
46 ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
47 """
48 Args:
49 server_name_and_key_ids:
4750 iterable of (server_name, key-id) tuples to fetch keys for
4851
4952 Returns:
50 Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
51 map from (server_name, key_id) -> FetchKeyResult, or None if the key is
52 unknown
53 A map from (server_name, key_id) -> FetchKeyResult, or None if the
54 key is unknown
5355 """
5456 keys = {}
5557
8587 _get_keys(txn, batch)
8688 return keys
8789
88 return self.db_pool.runInteraction("get_server_verify_keys", _txn)
89
90 def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
90 return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
91
92 async def store_server_verify_keys(
93 self,
94 from_server: str,
95 ts_added_ms: int,
96 verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
97 ) -> None:
9198 """Stores NACL verification keys for remote servers.
9299 Args:
93 from_server (str): Where the verification keys were looked up
94 ts_added_ms (int): The time to record that the key was added
95 verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
100 from_server: Where the verification keys were looked up
101 ts_added_ms: The time to record that the key was added
102 verify_keys:
96103 keys to be stored. Each entry is a triplet of
97104 (server_name, key_id, key).
98105 """
114121 # param, which is itself the 2-tuple (server_name, key_id).
115122 invalidations.append((server_name, key_id))
116123
117 def _invalidate(res):
118 f = self._get_server_verify_key.invalidate
119 for i in invalidations:
120 f((i,))
121 return res
122
123 return self.db_pool.runInteraction(
124 await self.db_pool.runInteraction(
124125 "store_server_verify_keys",
125126 self.db_pool.simple_upsert_many_txn,
126127 table="server_signature_keys",
133134 "verify_key",
134135 ),
135136 value_values=value_values,
136 ).addCallback(_invalidate)
137
138 def store_server_keys_json(
139 self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
140 ):
137 )
138
139 invalidate = self._get_server_verify_key.invalidate
140 for i in invalidations:
141 invalidate((i,))
142
143 async def store_server_keys_json(
144 self,
145 server_name: str,
146 key_id: str,
147 from_server: str,
148 ts_now_ms: int,
149 ts_expires_ms: int,
150 key_json_bytes: bytes,
151 ) -> None:
141152 """Stores the JSON bytes for a set of keys from a server
142153 The JSON should be signed by the originating server, the intermediate
143154 server, and by this server. Updates the value for the
144155 (server_name, key_id, from_server) triplet if one already existed.
145156 Args:
146 server_name (str): The name of the server.
147 key_id (str): The identifer of the key this JSON is for.
148 from_server (str): The server this JSON was fetched from.
149 ts_now_ms (int): The time now in milliseconds.
150 ts_valid_until_ms (int): The time when this json stops being valid.
151 key_json (bytes): The encoded JSON.
152 """
153 return self.db_pool.simple_upsert(
157 server_name: The name of the server.
158 key_id: The identifier of the key this JSON is for.
159 from_server: The server this JSON was fetched from.
160 ts_now_ms: The time now in milliseconds.
161 ts_expires_ms: The time when this JSON stops being valid.
162 key_json_bytes: The encoded JSON.
163 """
164 await self.db_pool.simple_upsert(
154165 table="server_keys_json",
155166 keyvalues={
156167 "server_name": server_name,
168179 desc="store_server_keys_json",
169180 )
170181
171 def get_server_keys_json(self, server_keys):
182 async def get_server_keys_json(
183 self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
184 ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
172185 """Retrive the key json for a list of server_keys and key ids.
173186 If no keys are found for a given server, key_id and source then
174187 that server, key_id, and source triplet entry will be an empty list.
177190 Args:
178191 server_keys (list): List of (server_name, key_id, source) triplets.
179192 Returns:
180 Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
181 Dict mapping (server_name, key_id, source) triplets to lists of dicts
193 A mapping from (server_name, key_id, source) triplets to a list of dicts
182194 """
183195
184196 def _get_server_keys_json_txn(txn):
204216 results[(server_name, key_id, from_server)] = rows
205217 return results
206218
207 return self.db_pool.runInteraction(
219 return await self.db_pool.runInteraction(
208220 "get_server_keys_json", _get_server_keys_json_txn
209221 )
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Dict, Iterable, List, Optional, Tuple
15
1416 from synapse.storage._base import SQLBaseStore
1517 from synapse.storage.database import DatabasePool
1618
3638 def __init__(self, database: DatabasePool, db_conn, hs):
3739 super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
3840
39 def get_local_media(self, media_id):
41 async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
4042 """Get the metadata for a local piece of media
43
4144 Returns:
4245 None if the media_id doesn't exist.
4346 """
44 return self.db_pool.simple_select_one(
47 return await self.db_pool.simple_select_one(
4548 "local_media_repository",
4649 {"media_id": media_id},
4750 (
5659 desc="get_local_media",
5760 )
5861
59 def store_local_media(
62 async def store_local_media(
6063 self,
6164 media_id,
6265 media_type,
6568 media_length,
6669 user_id,
6770 url_cache=None,
68 ):
69 return self.db_pool.simple_insert(
71 ) -> None:
72 await self.db_pool.simple_insert(
7073 "local_media_repository",
7174 {
7275 "media_id": media_id,
8083 desc="store_local_media",
8184 )
8285
83 def mark_local_media_as_safe(self, media_id: str):
86 async def mark_local_media_as_safe(self, media_id: str) -> None:
8487 """Mark a local media as safe from quarantining."""
85 return self.db_pool.simple_update_one(
88 await self.db_pool.simple_update_one(
8689 table="local_media_repository",
8790 keyvalues={"media_id": media_id},
8891 updatevalues={"safe_from_quarantine": True},
8992 desc="mark_local_media_as_safe",
9093 )
9194
92 def get_url_cache(self, url, ts):
95 async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
9396 """Get the media_id and ts for a cached URL as of the given timestamp
9497 Returns:
9598 None if the URL isn't cached.
135138 )
136139 )
137140
138 return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
139
140 def store_url_cache(
141 return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
142
143 async def store_url_cache(
141144 self, url, response_code, etag, expires_ts, og, media_id, download_ts
142145 ):
143 return self.db_pool.simple_insert(
146 await self.db_pool.simple_insert(
144147 "local_media_repository_url_cache",
145148 {
146149 "url": url,
154157 desc="store_url_cache",
155158 )
156159
157 def get_local_media_thumbnails(self, media_id):
158 return self.db_pool.simple_select_list(
160 async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
161 return await self.db_pool.simple_select_list(
159162 "local_media_repository_thumbnails",
160163 {"media_id": media_id},
161164 (
168171 desc="get_local_media_thumbnails",
169172 )
170173
171 def store_local_thumbnail(
174 async def store_local_thumbnail(
172175 self,
173176 media_id,
174177 thumbnail_width,
177180 thumbnail_method,
178181 thumbnail_length,
179182 ):
180 return self.db_pool.simple_insert(
183 await self.db_pool.simple_insert(
181184 "local_media_repository_thumbnails",
182185 {
183186 "media_id": media_id,
190193 desc="store_local_thumbnail",
191194 )
192195
193 def get_cached_remote_media(self, origin, media_id):
194 return self.db_pool.simple_select_one(
196 async def get_cached_remote_media(
197 self, origin: str, media_id: str
198 ) -> Optional[Dict[str, Any]]:
199 return await self.db_pool.simple_select_one(
195200 "remote_media_cache",
196201 {"media_origin": origin, "media_id": media_id},
197202 (
206211 desc="get_cached_remote_media",
207212 )
208213
209 def store_cached_remote_media(
214 async def store_cached_remote_media(
210215 self,
211216 origin,
212217 media_id,
216221 upload_name,
217222 filesystem_id,
218223 ):
219 return self.db_pool.simple_insert(
224 await self.db_pool.simple_insert(
220225 "remote_media_cache",
221226 {
222227 "media_origin": origin,
231236 desc="store_cached_remote_media",
232237 )
233238
234 def update_cached_last_access_time(self, local_media, remote_media, time_ms):
239 async def update_cached_last_access_time(
240 self,
241 local_media: Iterable[str],
242 remote_media: Iterable[Tuple[str, str]],
243 time_ms: int,
244 ):
235245 """Updates the last access time of the given media
236246
237247 Args:
238 local_media (iterable[str]): Set of media_ids
239 remote_media (iterable[(str, str)]): Set of (server_name, media_id)
248 local_media: Set of media_ids
249 remote_media: Set of (server_name, media_id)
240250 time_ms: Current time in milliseconds
241251 """
242252
261271
262272 txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
263273
264 return self.db_pool.runInteraction(
274 return await self.db_pool.runInteraction(
265275 "update_cached_last_access_time", update_cache_txn
266276 )
267277
268 def get_remote_media_thumbnails(self, origin, media_id):
269 return self.db_pool.simple_select_list(
278 async def get_remote_media_thumbnails(
279 self, origin: str, media_id: str
280 ) -> List[Dict[str, Any]]:
281 return await self.db_pool.simple_select_list(
270282 "remote_media_cache_thumbnails",
271283 {"media_origin": origin, "media_id": media_id},
272284 (
280292 desc="get_remote_media_thumbnails",
281293 )
282294
283 def store_remote_media_thumbnail(
295 async def store_remote_media_thumbnail(
284296 self,
285297 origin,
286298 media_id,
291303 thumbnail_method,
292304 thumbnail_length,
293305 ):
294 return self.db_pool.simple_insert(
306 await self.db_pool.simple_insert(
295307 "remote_media_cache_thumbnails",
296308 {
297309 "media_origin": origin,
306318 desc="store_remote_media_thumbnail",
307319 )
308320
309 def get_remote_media_before(self, before_ts):
321 async def get_remote_media_before(self, before_ts: int):
310322 sql = (
311323 "SELECT media_origin, media_id, filesystem_id"
312324 " FROM remote_media_cache"
313325 " WHERE last_access_ts < ?"
314326 )
315327
316 return self.db_pool.execute(
328 return await self.db_pool.execute(
317329 "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
318330 )
319331
320 def delete_remote_media(self, media_origin, media_id):
332 async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
321333 def delete_remote_media_txn(txn):
322334 self.db_pool.simple_delete_txn(
323335 txn,
330342 keyvalues={"media_origin": media_origin, "media_id": media_id},
331343 )
332344
333 return self.db_pool.runInteraction(
345 await self.db_pool.runInteraction(
334346 "delete_remote_media", delete_remote_media_txn
335347 )
336348
337 def get_expired_url_cache(self, now_ts):
349 async def get_expired_url_cache(self, now_ts: int) -> List[str]:
338350 sql = (
339351 "SELECT media_id FROM local_media_repository_url_cache"
340352 " WHERE expires_ts < ?"
346358 txn.execute(sql, (now_ts,))
347359 return [row[0] for row in txn]
348360
349 return self.db_pool.runInteraction(
361 return await self.db_pool.runInteraction(
350362 "get_expired_url_cache", _get_expired_url_cache_txn
351363 )
352364
363375 "delete_url_cache", _delete_url_cache_txn
364376 )
365377
366 def get_url_cache_media_before(self, before_ts):
378 async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
367379 sql = (
368380 "SELECT media_id FROM local_media_repository"
369381 " WHERE created_ts < ? AND url_cache IS NOT NULL"
375387 txn.execute(sql, (before_ts,))
376388 return [row[0] for row in txn]
377389
378 return self.db_pool.runInteraction(
390 return await self.db_pool.runInteraction(
379391 "get_url_cache_media_before", _get_url_cache_media_before_txn
380392 )
381393
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from typing import List
15 from typing import Dict, List, Optional
1616
1717 from synapse.storage._base import SQLBaseStore
1818 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
3232 self.hs = hs
3333
3434 @cached(num_args=0)
35 def get_monthly_active_count(self):
35 async def get_monthly_active_count(self) -> int:
3636 """Generates current count of monthly active users
3737
3838 Returns:
39 Defered[int]: Number of current monthly active users
39 Number of current monthly active users
4040 """
4141
4242 def _count_users(txn):
4545 (count,) = txn.fetchone()
4646 return count
4747
48 return self.db_pool.runInteraction("count_users", _count_users)
48 return await self.db_pool.runInteraction("count_users", _count_users)
4949
5050 @cached(num_args=0)
51 def get_monthly_active_count_by_service(self):
51 async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
5252 """Generates current count of monthly active users broken down by service.
5353 A service is typically an appservice but also includes native matrix users.
5454 Since the `monthly_active_users` table is populated from the `user_ips` table
5656 method to return anything other than native matrix users.
5757
5858 Returns:
59 Deferred[dict]: dict that includes a mapping between app_service_id
60 and the number of occurrences.
59 A mapping between app_service_id and the number of occurrences.
6160
6261 """
6362
7372 result = txn.fetchall()
7473 return dict(result)
7574
76 return self.db_pool.runInteraction(
75 return await self.db_pool.runInteraction(
7776 "count_users_by_service", _count_users_by_service
7877 )
7978
9897 return users
9998
10099 @cached(num_args=1)
101 def user_last_seen_monthly_active(self, user_id):
102 """
103 Checks if a given user is part of the monthly active user group
104 Arguments:
105 user_id (str): user to add/update
106 Return:
107 Deferred[int] : timestamp since last seen, None if never seen
108
109 """
110
111 return self.db_pool.simple_select_one_onecol(
100 async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]:
101 """
102 Checks if a given user is part of the monthly active user group
103
104 Args:
105 user_id: user to add/update
106
107 Returns:
108 Timestamp when the user was last seen, or None if never seen
109 """
110
111 return await self.db_pool.simple_select_one_onecol(
112112 table="monthly_active_users",
113113 keyvalues={"user_id": user_id},
114114 retcol="timestamp",
0 from typing import Optional
1
02 from synapse.storage._base import SQLBaseStore
13
24
35 class OpenIdStore(SQLBaseStore):
4 def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
5 return self.db_pool.simple_insert(
6 async def insert_open_id_token(
7 self, token: str, ts_valid_until_ms: int, user_id: str
8 ) -> None:
9 await self.db_pool.simple_insert(
610 table="open_id_tokens",
711 values={
812 "token": token,
1216 desc="insert_open_id_token",
1317 )
1418
15 def get_user_id_for_open_id_token(self, token, ts_now_ms):
19 async def get_user_id_for_open_id_token(
20 self, token: str, ts_now_ms: int
21 ) -> Optional[str]:
1622 def get_user_id_for_token_txn(txn):
1723 sql = (
1824 "SELECT user_id FROM open_id_tokens"
2733 else:
2834 return rows[0][0]
2935
30 return self.db_pool.runInteraction(
36 return await self.db_pool.runInteraction(
3137 "get_user_id_for_token", get_user_id_for_token_txn
3238 )
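`insert_open_id_token` stores a token with an absolute expiry, and `get_user_id_for_open_id_token` only resolves it while `ts_now_ms` is inside that window. A hypothetical round trip (the token, user ID and timestamps are placeholders):

    # Hypothetical usage; token, user ID and timestamps are placeholders.
    now_ms = 1600000000000
    await store.insert_open_id_token(
        token="abcdef",
        ts_valid_until_ms=now_ms + 3600 * 1000,  # valid for one hour
        user_id="@alice:example.com",
    )
    user = await store.get_user_id_for_open_id_token("abcdef", ts_now_ms=now_ms)
    # user == "@alice:example.com"; None is returned once the token has expired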
1414
1515 from typing import List, Tuple
1616
17 from synapse.api.presence import UserPresenceState
1718 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
18 from synapse.storage.presence import UserPresenceState
1919 from synapse.util.caches.descriptors import cached, cachedList
2020 from synapse.util.iterutils import batch_iter
2121
2222
2323 class PresenceStore(SQLBaseStore):
2424 async def update_presence(self, presence_states):
25 stream_ordering_manager = self._presence_id_gen.get_next_mult(
25 stream_ordering_manager = await self._presence_id_gen.get_next_mult(
2626 len(presence_states)
2727 )
2828
129129 raise NotImplementedError()
130130
131131 @cachedList(
132 cached_method_name="_get_presence_for_user",
133 list_name="user_ids",
134 num_args=1,
135 inlineCallbacks=True,
132 cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
136133 )
137 def get_presence_for_users(self, user_ids):
138 rows = yield self.db_pool.simple_select_many_batch(
134 async def get_presence_for_users(self, user_ids):
135 rows = await self.db_pool.simple_select_many_batch(
139136 table="presence_stream",
140137 column="user_id",
141138 iterable=user_ids,
159156
160157 def get_current_presence_token(self):
161158 return self._presence_id_gen.get_current_token()
162
163 def allow_presence_visible(self, observed_localpart, observer_userid):
164 return self.db_pool.simple_insert(
165 table="presence_allow_inbound",
166 values={
167 "observed_user_id": observed_localpart,
168 "observer_user_id": observer_userid,
169 },
170 desc="allow_presence_visible",
171 or_ignore=True,
172 )
173
174 def disallow_presence_visible(self, observed_localpart, observer_userid):
175 return self.db_pool.simple_delete_one(
176 table="presence_allow_inbound",
177 keyvalues={
178 "observed_user_id": observed_localpart,
179 "observer_user_id": observer_userid,
180 },
181 desc="disallow_presence_visible",
182 )
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Dict, List, Optional
1415
1516 from synapse.api.errors import StoreError
1617 from synapse.storage._base import SQLBaseStore
1819
1920
2021 class ProfileWorkerStore(SQLBaseStore):
21 async def get_profileinfo(self, user_localpart):
22 async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
2223 try:
2324 profile = await self.db_pool.simple_select_one(
2425 table="profiles",
3738 avatar_url=profile["avatar_url"], display_name=profile["displayname"]
3839 )
3940
40 def get_profile_displayname(self, user_localpart):
41 return self.db_pool.simple_select_one_onecol(
41 async def get_profile_displayname(self, user_localpart: str) -> str:
42 return await self.db_pool.simple_select_one_onecol(
4243 table="profiles",
4344 keyvalues={"user_id": user_localpart},
4445 retcol="displayname",
4546 desc="get_profile_displayname",
4647 )
4748
48 def get_profile_avatar_url(self, user_localpart):
49 return self.db_pool.simple_select_one_onecol(
49 async def get_profile_avatar_url(self, user_localpart: str) -> str:
50 return await self.db_pool.simple_select_one_onecol(
5051 table="profiles",
5152 keyvalues={"user_id": user_localpart},
5253 retcol="avatar_url",
5354 desc="get_profile_avatar_url",
5455 )
5556
56 def get_from_remote_profile_cache(self, user_id):
57 return self.db_pool.simple_select_one(
57 async def get_from_remote_profile_cache(
58 self, user_id: str
59 ) -> Optional[Dict[str, Any]]:
60 return await self.db_pool.simple_select_one(
5861 table="remote_profile_cache",
5962 keyvalues={"user_id": user_id},
6063 retcols=("displayname", "avatar_url"),
6265 desc="get_from_remote_profile_cache",
6366 )
6467
65 def create_profile(self, user_localpart):
66 return self.db_pool.simple_insert(
68 async def create_profile(self, user_localpart: str) -> None:
69 await self.db_pool.simple_insert(
6770 table="profiles", values={"user_id": user_localpart}, desc="create_profile"
6871 )
6972
70 def set_profile_displayname(self, user_localpart, new_displayname):
71 return self.db_pool.simple_update_one(
73 async def set_profile_displayname(
74 self, user_localpart: str, new_displayname: str
75 ) -> None:
76 await self.db_pool.simple_update_one(
7277 table="profiles",
7378 keyvalues={"user_id": user_localpart},
7479 updatevalues={"displayname": new_displayname},
7580 desc="set_profile_displayname",
7681 )
7782
78 def set_profile_avatar_url(self, user_localpart, new_avatar_url):
79 return self.db_pool.simple_update_one(
83 async def set_profile_avatar_url(
84 self, user_localpart: str, new_avatar_url: str
85 ) -> None:
86 await self.db_pool.simple_update_one(
8087 table="profiles",
8188 keyvalues={"user_id": user_localpart},
8289 updatevalues={"avatar_url": new_avatar_url},
8592
8693
8794 class ProfileStore(ProfileWorkerStore):
88 def add_remote_profile_cache(self, user_id, displayname, avatar_url):
95 async def add_remote_profile_cache(
96 self, user_id: str, displayname: str, avatar_url: str
97 ) -> None:
8998 """Ensure we are caching the remote user's profiles.
9099
91100 This should only be called when `is_subscribed_remote_profile_for_user`
92101 would return true for the user.
93102 """
94 return self.db_pool.simple_upsert(
103 await self.db_pool.simple_upsert(
95104 table="remote_profile_cache",
96105 keyvalues={"user_id": user_id},
97106 values={
102111 desc="add_remote_profile_cache",
103112 )
104113
105 def update_remote_profile_cache(self, user_id, displayname, avatar_url):
106 return self.db_pool.simple_update(
114 async def update_remote_profile_cache(
115 self, user_id: str, displayname: str, avatar_url: str
116 ) -> int:
117 return await self.db_pool.simple_update(
107118 table="remote_profile_cache",
108119 keyvalues={"user_id": user_id},
109120 updatevalues={
126137 desc="delete_remote_profile_cache",
127138 )
128139
129 def get_remote_profile_cache_entries_that_expire(self, last_checked):
140 async def get_remote_profile_cache_entries_that_expire(
141 self, last_checked: int
142 ) -> List[Dict[str, str]]:
130143 """Get all users who haven't been checked since `last_checked`
131144 """
132145
141154
142155 return self.db_pool.cursor_to_dict(txn)
143156
144 return self.db_pool.runInteraction(
157 return await self.db_pool.runInteraction(
145158 "get_remote_profile_cache_entries_that_expire",
146159 _get_remote_profile_cache_entries_that_expire_txn,
147160 )
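`add_remote_profile_cache` is an upsert, so calling it again for the same user refreshes the cached profile rather than erroring, while `update_remote_profile_cache` returns the number of rows it changed. A hypothetical sequence (the user ID and mxc:// URLs are placeholders):

    # Hypothetical usage; the user ID and mxc:// URLs are placeholders.
    await store.add_remote_profile_cache(
        "@bob:remote.example", displayname="Bob", avatar_url="mxc://remote.example/xyz"
    )
    changed = await store.update_remote_profile_cache(
        "@bob:remote.example", displayname="Bobby", avatar_url="mxc://remote.example/xyz"
    )
    # changed is the number of rows updated (1 if the cache entry existed)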
1313 # limitations under the License.
1414
1515 import logging
16 from typing import Any, Tuple
16 from typing import Any, List, Set, Tuple
1717
1818 from synapse.api.errors import SynapseError
1919 from synapse.storage._base import SQLBaseStore
2424
2525
2626 class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
27 def purge_history(self, room_id, token, delete_local_events):
27 async def purge_history(
28 self, room_id: str, token: str, delete_local_events: bool
29 ) -> Set[int]:
2830 """Deletes room history before a certain point
2931
3032 Args:
31 room_id (str):
32
33 token (str): A topological token to delete events before
34
35 delete_local_events (bool):
33 room_id:
34 token: A topological token to delete events before
35 delete_local_events:
3636 if True, we will delete local events as well as remote ones
3737 (instead of just marking them as outliers and deleting their
3838 state groups).
3939
4040 Returns:
41 Deferred[set[int]]: The set of state groups that are referenced by
42 deleted events.
41 The set of state groups that are referenced by deleted events.
4342 """
4443
45 return self.db_pool.runInteraction(
44 return await self.db_pool.runInteraction(
4645 "purge_history",
4746 self._purge_history_txn,
4847 room_id,
282281
283282 return referenced_state_groups
284283
285 def purge_room(self, room_id):
284 async def purge_room(self, room_id: str) -> List[int]:
286285 """Deletes all record of a room
287286
288287 Args:
289 room_id (str)
288 room_id
290289
291290 Returns:
292 Deferred[List[int]]: The list of state groups to delete.
291 The list of state groups to delete.
293292 """
294
295 return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
293 return await self.db_pool.runInteraction(
294 "purge_room", self._purge_room_txn, room_id
295 )
296296
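Note that `purge_room` deletes the room's rows but deliberately does not touch the state groups it reports; the caller is expected to dispose of those in a separate step against the state store. A sketch of the assumed calling pattern (`purge_room_state` is a placeholder name here, not taken from this diff):

    # Assumed calling pattern; `purge_room_state` is a placeholder name.
    state_groups = await store.purge_room("!room:example.com")
    await state_store.purge_room_state("!room:example.com", state_groups)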
297297 def _purge_room_txn(self, txn, room_id):
298298 # First we fetch all the state groups that should be deleted, before
1717 import logging
1818 from typing import List, Tuple, Union
1919
20 from twisted.internet import defer
21
2220 from synapse.push.baserules import list_with_base_rules
2321 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
2422 from synapse.storage._base import SQLBaseStore, db_to_json
2927 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
3028 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
3129 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
32 from synapse.storage.util.id_generators import ChainedIdGenerator
30 from synapse.storage.util.id_generators import StreamIdGenerator
3331 from synapse.util import json_encoder
34 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
32 from synapse.util.caches.descriptors import cached, cachedList
3533 from synapse.util.caches.stream_change_cache import StreamChangeCache
3634
3735 logger = logging.getLogger(__name__)
8179 super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
8280
8381 if hs.config.worker.worker_app is None:
84 self._push_rules_stream_id_gen = ChainedIdGenerator(
85 self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
86 ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
82 self._push_rules_stream_id_gen = StreamIdGenerator(
83 db_conn, "push_rules_stream", "stream_id"
84 ) # type: Union[StreamIdGenerator, SlavedIdTracker]
8785 else:
8886 self._push_rules_stream_id_gen = SlavedIdTracker(
8987 db_conn, "push_rules_stream", "stream_id"
114112 """
115113 raise NotImplementedError()
116114
117 @cachedInlineCallbacks(max_entries=5000)
118 def get_push_rules_for_user(self, user_id):
119 rows = yield self.db_pool.simple_select_list(
115 @cached(max_entries=5000)
116 async def get_push_rules_for_user(self, user_id):
117 rows = await self.db_pool.simple_select_list(
120118 table="push_rules",
121119 keyvalues={"user_name": user_id},
122120 retcols=(
132130
133131 rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
134132
135 enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
133 enabled_map = await self.get_push_rules_enabled_for_user(user_id)
136134
137135 use_new_defaults = user_id in self._users_new_default_push_rules
138136
139 rules = _load_rules(rows, enabled_map, use_new_defaults)
140
141 return rules
142
143 @cachedInlineCallbacks(max_entries=5000)
144 def get_push_rules_enabled_for_user(self, user_id):
145 results = yield self.db_pool.simple_select_list(
137 return _load_rules(rows, enabled_map, use_new_defaults)
138
139 @cached(max_entries=5000)
140 async def get_push_rules_enabled_for_user(self, user_id):
141 results = await self.db_pool.simple_select_list(
146142 table="push_rules_enable",
147143 keyvalues={"user_name": user_id},
148144 retcols=("user_name", "rule_id", "enabled"),
150146 )
151147 return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
152148
153 def have_push_rules_changed_for_user(self, user_id, last_id):
149 async def have_push_rules_changed_for_user(
150 self, user_id: str, last_id: int
151 ) -> bool:
154152 if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
155 return defer.succeed(False)
153 return False
156154 else:
157155
158156 def have_push_rules_changed_txn(txn):
164162 (count,) = txn.fetchone()
165163 return bool(count)
166164
167 return self.db_pool.runInteraction(
165 return await self.db_pool.runInteraction(
168166 "have_push_rules_changed", have_push_rules_changed_txn
169167 )
170168
171169 @cachedList(
172 cached_method_name="get_push_rules_for_user",
173 list_name="user_ids",
174 num_args=1,
175 inlineCallbacks=True,
170 cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
176171 )
177 def bulk_get_push_rules(self, user_ids):
172 async def bulk_get_push_rules(self, user_ids):
178173 if not user_ids:
179174 return {}
180175
181176 results = {user_id: [] for user_id in user_ids}
182177
183 rows = yield self.db_pool.simple_select_many_batch(
178 rows = await self.db_pool.simple_select_many_batch(
184179 table="push_rules",
185180 column="user_name",
186181 iterable=user_ids,
193188 for row in rows:
194189 results.setdefault(row["user_name"], []).append(row)
195190
196 enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
191 enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
197192
198193 for user_id, rules in results.items():
199194 use_new_defaults = user_id in self._users_new_default_push_rules
204199
205200 return results
206201
207 @defer.inlineCallbacks
208 def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
202 async def copy_push_rule_from_room_to_room(
203 self, new_room_id: str, user_id: str, rule: dict
204 ) -> None:
209205 """Copy a single push rule from one room to another for a specific user.
210206
211207 Args:
212 new_room_id (str): ID of the new room.
213 user_id (str): ID of user the push rule belongs to.
214 rule (Dict): A push rule.
208 new_room_id: ID of the new room.
209 user_id: ID of user the push rule belongs to.
210 rule: A push rule.
215211 """
216212 # Create new rule id
217213 rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
223219 condition["pattern"] = new_room_id
224220
225221 # Add the rule for the new room
226 yield self.add_push_rule(
222 await self.add_push_rule(
227223 user_id=user_id,
228224 rule_id=new_rule_id,
229225 priority_class=rule["priority_class"],
231227 actions=rule["actions"],
232228 )
233229
234 @defer.inlineCallbacks
235 def copy_push_rules_from_room_to_room_for_user(
236 self, old_room_id, new_room_id, user_id
237 ):
230 async def copy_push_rules_from_room_to_room_for_user(
231 self, old_room_id: str, new_room_id: str, user_id: str
232 ) -> None:
238233 """Copy all of the push rules from one room to another for a specific
239234 user.
240235
241236 Args:
242 old_room_id (str): ID of the old room.
243 new_room_id (str): ID of the new room.
244 user_id (str): ID of user to copy push rules for.
237 old_room_id: ID of the old room.
238 new_room_id: ID of the new room.
239 user_id: ID of user to copy push rules for.
245240 """
246241 # Retrieve push rules for this user
247 user_push_rules = yield self.get_push_rules_for_user(user_id)
242 user_push_rules = await self.get_push_rules_for_user(user_id)
248243
249244 # Get rules relating to the old room and copy them to the new room
250245 for rule in user_push_rules:
253248 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
254249 for c in conditions
255250 ):
256 yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
251 await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
257252
258253 @cachedList(
259254 cached_method_name="get_push_rules_enabled_for_user",
260255 list_name="user_ids",
261256 num_args=1,
262 inlineCallbacks=True,
263257 )
264 def bulk_get_push_rules_enabled(self, user_ids):
258 async def bulk_get_push_rules_enabled(self, user_ids):
265259 if not user_ids:
266260 return {}
267261
268262 results = {user_id: {} for user_id in user_ids}
269263
270 rows = yield self.db_pool.simple_select_many_batch(
264 rows = await self.db_pool.simple_select_many_batch(
271265 table="push_rules_enable",
272266 column="user_name",
273267 iterable=user_ids,
331325
332326
333327 class PushRuleStore(PushRulesWorkerStore):
334 @defer.inlineCallbacks
335 def add_push_rule(
328 async def add_push_rule(
336329 self,
337330 user_id,
338331 rule_id,
341334 actions,
342335 before=None,
343336 after=None,
344 ):
337 ) -> None:
345338 conditions_json = json_encoder.encode(conditions)
346339 actions_json = json_encoder.encode(actions)
347 with self._push_rules_stream_id_gen.get_next() as ids:
348 stream_id, event_stream_ordering = ids
340 with await self._push_rules_stream_id_gen.get_next() as stream_id:
341 event_stream_ordering = self._stream_id_gen.get_current_token()
342
349343 if before or after:
350 yield self.db_pool.runInteraction(
344 await self.db_pool.runInteraction(
351345 "_add_push_rule_relative_txn",
352346 self._add_push_rule_relative_txn,
353347 stream_id,
361355 after,
362356 )
363357 else:
364 yield self.db_pool.runInteraction(
358 await self.db_pool.runInteraction(
365359 "_add_push_rule_highest_priority_txn",
366360 self._add_push_rule_highest_priority_txn,
367361 stream_id,
545539 },
546540 )
547541
548 @defer.inlineCallbacks
549 def delete_push_rule(self, user_id, rule_id):
542 async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
550543 """
551544 Delete a push rule. Args specify the row to be deleted and can be
552545 any of the columns in the push_rule table, but below are the
553546 standard ones
554547
555548 Args:
556 user_id (str): The matrix ID of the push rule owner
557 rule_id (str): The rule_id of the rule to be deleted
549 user_id: The matrix ID of the push rule owner
550 rule_id: The rule_id of the rule to be deleted
558551 """
559552
560553 def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
566559 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
567560 )
568561
569 with self._push_rules_stream_id_gen.get_next() as ids:
570 stream_id, event_stream_ordering = ids
571 yield self.db_pool.runInteraction(
562 with await self._push_rules_stream_id_gen.get_next() as stream_id:
563 event_stream_ordering = self._stream_id_gen.get_current_token()
564
565 await self.db_pool.runInteraction(
572566 "delete_push_rule",
573567 delete_push_rule_txn,
574568 stream_id,
575569 event_stream_ordering,
576570 )
577571
578 @defer.inlineCallbacks
579 def set_push_rule_enabled(self, user_id, rule_id, enabled):
580 with self._push_rules_stream_id_gen.get_next() as ids:
581 stream_id, event_stream_ordering = ids
582 yield self.db_pool.runInteraction(
572 async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
573 with await self._push_rules_stream_id_gen.get_next() as stream_id:
574 event_stream_ordering = self._stream_id_gen.get_current_token()
575
576 await self.db_pool.runInteraction(
583577 "_set_push_rule_enabled_txn",
584578 self._set_push_rule_enabled_txn,
585579 stream_id,
610604 op="ENABLE" if enabled else "DISABLE",
611605 )
612606
613 @defer.inlineCallbacks
614 def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
607 async def set_push_rule_actions(
608 self, user_id, rule_id, actions, is_default_rule
609 ) -> None:
615610 actions_json = json_encoder.encode(actions)
616611
617612 def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
650645 data={"actions": actions_json},
651646 )
652647
653 with self._push_rules_stream_id_gen.get_next() as ids:
654 stream_id, event_stream_ordering = ids
655 yield self.db_pool.runInteraction(
648 with await self._push_rules_stream_id_gen.get_next() as stream_id:
649 event_stream_ordering = self._stream_id_gen.get_current_token()
650
651 await self.db_pool.runInteraction(
656652 "set_push_rule_actions",
657653 set_push_rule_actions_txn,
658654 stream_id,
680676 self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
681677 )
682678
683 def get_push_rules_stream_token(self):
684 """Get the position of the push rules stream.
685 Returns a pair of a stream id for the push_rules stream and the
686 room stream ordering it corresponds to."""
679 def get_max_push_rules_stream_id(self):
687680 return self._push_rules_stream_id_gen.get_current_token()
688
689 def get_max_push_rules_stream_id(self):
690 return self.get_push_rules_stream_token()[0]
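The `ChainedIdGenerator` to `StreamIdGenerator` swap above changes what `get_next()` hands back: previously a `(stream_id, event_stream_ordering)` pair, now a single stream id, with the room-stream position read separately via `self._stream_id_gen.get_current_token()`. A runnable toy model of the new pattern (the classes here are simplified stand-ins, not the real Synapse generators):

    import asyncio
    from contextlib import contextmanager

    @contextmanager
    def _id_context(stream_id):
        # The real generator returns a context manager guarding the new id.
        yield stream_id

    class ToyStreamIdGenerator:
        """Toy stand-in for StreamIdGenerator; hands out increasing ids."""

        def __init__(self, start=0):
            self._current = start

        async def get_next(self):
            self._current += 1
            return _id_context(self._current)

        def get_current_token(self):
            return self._current

    async def main():
        push_rules_gen = ToyStreamIdGenerator()
        main_stream_gen = ToyStreamIdGenerator(start=100)
        # New pattern: one id from the push-rules stream; the room stream
        # position is read separately off the main stream generator.
        with await push_rules_gen.get_next() as stream_id:
            event_stream_ordering = main_stream_gen.get_current_token()
            print(stream_id, event_stream_ordering)  # prints: 1 100

    asyncio.run(main())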
1818
1919 from canonicaljson import encode_canonical_json
2020
21 from twisted.internet import defer
22
2321 from synapse.storage._base import SQLBaseStore, db_to_json
24 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
22 from synapse.util.caches.descriptors import cached, cachedList
2523
2624 logger = logging.getLogger(__name__)
2725
3331 Drops any rows whose data cannot be decoded
3432 """
3533 for r in rows:
36 dataJson = r["data"]
34 data_json = r["data"]
3735 try:
38 r["data"] = db_to_json(dataJson)
36 r["data"] = db_to_json(data_json)
3937 except Exception as e:
4038 logger.warning(
4139 "Invalid JSON in data for pusher %d: %s, %s",
4240 r["id"],
43 dataJson,
41 data_json,
4442 e.args[0],
4543 )
4644 continue
4745
4846 yield r
4947
50 @defer.inlineCallbacks
51 def user_has_pusher(self, user_id):
52 ret = yield self.db_pool.simple_select_one_onecol(
48 async def user_has_pusher(self, user_id):
49 ret = await self.db_pool.simple_select_one_onecol(
5350 "pushers", {"user_name": user_id}, "id", allow_none=True
5451 )
5552 return ret is not None
6057 def get_pushers_by_user_id(self, user_id):
6158 return self.get_pushers_by({"user_name": user_id})
6259
63 @defer.inlineCallbacks
64 def get_pushers_by(self, keyvalues):
65 ret = yield self.db_pool.simple_select_list(
60 async def get_pushers_by(self, keyvalues):
61 ret = await self.db_pool.simple_select_list(
6662 "pushers",
6763 keyvalues,
6864 [
8682 )
8783 return self._decode_pushers_rows(ret)
8884
89 @defer.inlineCallbacks
90 def get_all_pushers(self):
85 async def get_all_pushers(self):
9186 def get_pushers(txn):
9287 txn.execute("SELECT * FROM pushers")
9388 rows = self.db_pool.cursor_to_dict(txn)
9489
9590 return self._decode_pushers_rows(rows)
9691
97 rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
98 return rows
92 return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
9993
10094 async def get_all_updated_pushers_rows(
10195 self, instance_name: str, last_id: int, current_id: int, limit: int
163157 "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
164158 )
165159
166 @cachedInlineCallbacks(num_args=1, max_entries=15000)
167 def get_if_user_has_pusher(self, user_id):
160 @cached(num_args=1, max_entries=15000)
161 async def get_if_user_has_pusher(self, user_id):
168162 # This only exists for the cachedList decorator
169163 raise NotImplementedError()
170164
171165 @cachedList(
172 cached_method_name="get_if_user_has_pusher",
173 list_name="user_ids",
174 num_args=1,
175 inlineCallbacks=True,
166 cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
176167 )
177 def get_if_users_have_pushers(self, user_ids):
178 rows = yield self.db_pool.simple_select_many_batch(
168 async def get_if_users_have_pushers(self, user_ids):
169 rows = await self.db_pool.simple_select_many_batch(
179170 table="pushers",
180171 column="user_name",
181172 iterable=user_ids,
188179
189180 return result
190181
191 @defer.inlineCallbacks
192 def update_pusher_last_stream_ordering(
182 async def update_pusher_last_stream_ordering(
193183 self, app_id, pushkey, user_id, last_stream_ordering
194 ):
195 yield self.db_pool.simple_update_one(
184 ) -> None:
185 await self.db_pool.simple_update_one(
196186 "pushers",
197187 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
198188 {"last_stream_ordering": last_stream_ordering},
199189 desc="update_pusher_last_stream_ordering",
200190 )
201191
202 @defer.inlineCallbacks
203 def update_pusher_last_stream_ordering_and_success(
204 self, app_id, pushkey, user_id, last_stream_ordering, last_success
205 ):
192 async def update_pusher_last_stream_ordering_and_success(
193 self,
194 app_id: str,
195 pushkey: str,
196 user_id: str,
197 last_stream_ordering: int,
198 last_success: int,
199 ) -> bool:
206200 """Update the last stream ordering position we've processed up to for
207201 the given pusher.
208202
209203 Args:
210 app_id (str)
211 pushkey (str)
212 last_stream_ordering (int)
213 last_success (int)
204 app_id
205 pushkey
206 user_id
207 last_stream_ordering
208 last_success
214209
215210 Returns:
216 Deferred[bool]: True if the pusher still exists; False if it has been deleted.
211 True if the pusher still exists; False if it has been deleted.
217212 """
218 updated = yield self.db_pool.simple_update(
213 updated = await self.db_pool.simple_update(
219214 table="pushers",
220215 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
221216 updatevalues={
227222
228223 return bool(updated)
229224
230 @defer.inlineCallbacks
231 def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
232 yield self.db_pool.simple_update(
225 async def update_pusher_failing_since(
226 self, app_id, pushkey, user_id, failing_since
227 ) -> None:
228 await self.db_pool.simple_update(
233229 table="pushers",
234230 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
235231 updatevalues={"failing_since": failing_since},
236232 desc="update_pusher_failing_since",
237233 )
238234
239 @defer.inlineCallbacks
240 def get_throttle_params_by_room(self, pusher_id):
241 res = yield self.db_pool.simple_select_list(
235 async def get_throttle_params_by_room(self, pusher_id):
236 res = await self.db_pool.simple_select_list(
242237 "pusher_throttle",
243238 {"pusher": pusher_id},
244239 ["room_id", "last_sent_ts", "throttle_ms"],
254249
255250 return params_by_room
256251
257 @defer.inlineCallbacks
258 def set_throttle_params(self, pusher_id, room_id, params):
252 async def set_throttle_params(self, pusher_id, room_id, params) -> None:
259253 # no need to lock because `pusher_throttle` has a primary key on
260254 # (pusher, room_id) so simple_upsert will retry
261 yield self.db_pool.simple_upsert(
255 await self.db_pool.simple_upsert(
262256 "pusher_throttle",
263257 {"pusher": pusher_id, "room_id": room_id},
264258 params,
271265 def get_pushers_stream_token(self):
272266 return self._pushers_id_gen.get_current_token()
273267
274 @defer.inlineCallbacks
275 def add_pusher(
268 async def add_pusher(
276269 self,
277270 user_id,
278271 access_token,
286279 data,
287280 last_stream_ordering,
288281 profile_tag="",
289 ):
290 with self._pushers_id_gen.get_next() as stream_id:
282 ) -> None:
283 with await self._pushers_id_gen.get_next() as stream_id:
291284 # no need to lock because `pushers` has a unique key on
292285 # (app_id, pushkey, user_name) so simple_upsert will retry
293 yield self.db_pool.simple_upsert(
286 await self.db_pool.simple_upsert(
294287 table="pushers",
295288 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
296289 values={
315308
316309 if user_has_pusher is not True:
317310 # invalidate, since the user might not have had a pusher before
318 yield self.db_pool.runInteraction(
311 await self.db_pool.runInteraction(
319312 "add_pusher",
320313 self._invalidate_cache_and_stream,
321314 self.get_if_user_has_pusher,
322315 (user_id,),
323316 )
324317
325 @defer.inlineCallbacks
326 def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
318 async def delete_pusher_by_app_id_pushkey_user_id(
319 self, app_id, pushkey, user_id
320 ) -> None:
327321 def delete_pusher_txn(txn, stream_id):
328322 self._invalidate_cache_and_stream(
329323 txn, self.get_if_user_has_pusher, (user_id,)
349343 },
350344 )
351345
352 with self._pushers_id_gen.get_next() as stream_id:
353 yield self.db_pool.runInteraction(
346 with await self._pushers_id_gen.get_next() as stream_id:
347 await self.db_pool.runInteraction(
354348 "delete_pusher", delete_pusher_txn, stream_id
355349 )
1515
1616 import abc
1717 import logging
18 from typing import List, Tuple
18 from typing import Any, Dict, List, Optional, Tuple
1919
2020 from twisted.internet import defer
2121
2424 from synapse.storage.util.id_generators import StreamIdGenerator
2525 from synapse.util import json_encoder
2626 from synapse.util.async_helpers import ObservableDeferred
27 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
27 from synapse.util.caches.descriptors import cached, cachedList
2828 from synapse.util.caches.stream_change_cache import StreamChangeCache
2929
3030 logger = logging.getLogger(__name__)
5555 """
5656 raise NotImplementedError()
5757
58 @cachedInlineCallbacks()
59 def get_users_with_read_receipts_in_room(self, room_id):
60 receipts = yield self.get_receipts_for_room(room_id, "m.read")
58 @cached()
59 async def get_users_with_read_receipts_in_room(self, room_id):
60 receipts = await self.get_receipts_for_room(room_id, "m.read")
6161 return {r["user_id"] for r in receipts}
6262
6363 @cached(num_args=2)
64 def get_receipts_for_room(self, room_id, receipt_type):
65 return self.db_pool.simple_select_list(
64 async def get_receipts_for_room(
65 self, room_id: str, receipt_type: str
66 ) -> List[Dict[str, Any]]:
67 return await self.db_pool.simple_select_list(
6668 table="receipts_linearized",
6769 keyvalues={"room_id": room_id, "receipt_type": receipt_type},
6870 retcols=("user_id", "event_id"),
7072 )
7173
7274 @cached(num_args=3)
73 def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
74 return self.db_pool.simple_select_one_onecol(
75 async def get_last_receipt_event_id_for_user(
76 self, user_id: str, room_id: str, receipt_type: str
77 ) -> Optional[str]:
78 return await self.db_pool.simple_select_one_onecol(
7579 table="receipts_linearized",
7680 keyvalues={
7781 "room_id": room_id,
8387 allow_none=True,
8488 )
8589
86 @cachedInlineCallbacks(num_args=2)
87 def get_receipts_for_user(self, user_id, receipt_type):
88 rows = yield self.db_pool.simple_select_list(
90 @cached(num_args=2)
91 async def get_receipts_for_user(self, user_id, receipt_type):
92 rows = await self.db_pool.simple_select_list(
8993 table="receipts_linearized",
9094 keyvalues={"user_id": user_id, "receipt_type": receipt_type},
9195 retcols=("room_id", "event_id"),
9498
9599 return {row["room_id"]: row["event_id"] for row in rows}
96100
97 @defer.inlineCallbacks
98 def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
101 async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
99102 def f(txn):
100103 sql = (
101104 "SELECT rl.room_id, rl.event_id,"
109112 txn.execute(sql, (user_id,))
110113 return txn.fetchall()
111114
112 rows = yield self.db_pool.runInteraction(
115 rows = await self.db_pool.runInteraction(
113116 "get_receipts_for_user_with_orderings", f
114117 )
115118 return {
121124 for row in rows
122125 }
123126
124 @defer.inlineCallbacks
125 def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
127 async def get_linearized_receipts_for_rooms(
128 self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
129 ) -> List[dict]:
126130 """Get receipts for multiple rooms for sending to clients.
127131
128132 Args:
129 room_ids (list): List of room_ids.
130 to_key (int): Max stream id to fetch receipts upto.
131 from_key (int): Min stream id to fetch receipts from. None fetches
133 room_ids: List of room_ids.
134 to_key: Max stream id to fetch receipts up to.
135 from_key: Min stream id to fetch receipts from. None fetches
132136 from the start.
133137
134138 Returns:
135 list: A list of receipts.
139 A list of receipts.
136140 """
137141 room_ids = set(room_ids)
138142
139143 if from_key is not None:
140144 # Only ask the database about rooms where there have been new
141145 # receipts added since `from_key`
142 room_ids = yield self._receipts_stream_cache.get_entities_changed(
146 room_ids = self._receipts_stream_cache.get_entities_changed(
143147 room_ids, from_key
144148 )
145149
146 results = yield self._get_linearized_receipts_for_rooms(
150 results = await self._get_linearized_receipts_for_rooms(
147151 room_ids, to_key, from_key=from_key
148152 )
149153
150154 return [ev for res in results.values() for ev in res]
151155
152 def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
156 async def get_linearized_receipts_for_room(
157 self, room_id: str, to_key: int, from_key: Optional[int] = None
158 ) -> List[dict]:
153159 """Get receipts for a single room for sending to clients.
154160
155161 Args:
156 room_ids (str): The room id.
157 to_key (int): Max stream id to fetch receipts upto.
158 from_key (int): Min stream id to fetch receipts from. None fetches
162 room_id: The room id.
163 to_key: Max stream id to fetch receipts up to.
164 from_key: Min stream id to fetch receipts from. None fetches
159165 from the start.
160166
161167 Returns:
162 Deferred[list]: A list of receipts.
168 A list of receipts.
163169 """
164170 if from_key is not None:
165171 # Check the cache first to see if any new receipts have been added
166172 # since`from_key`. If not we can no-op.
167173 if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
168 defer.succeed([])
169
170 return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
171
172 @cachedInlineCallbacks(num_args=3, tree=True)
173 def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
174 return []
175
176 return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
177
178 @cached(num_args=3, tree=True)
179 async def _get_linearized_receipts_for_room(
180 self, room_id: str, to_key: int, from_key: Optional[int] = None
181 ) -> List[dict]:
174182 """See get_linearized_receipts_for_room
175183 """
176184
194202
195203 return rows
196204
197 rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
205 rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
198206
199207 if not rows:
200208 return []
211219 cached_method_name="_get_linearized_receipts_for_room",
212220 list_name="room_ids",
213221 num_args=3,
214 inlineCallbacks=True,
215222 )
216 def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
223 async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
217224 if not room_ids:
218225 return {}
219226
242249
243250 return self.db_pool.cursor_to_dict(txn)
244251
245 txn_results = yield self.db_pool.runInteraction(
252 txn_results = await self.db_pool.runInteraction(
246253 "_get_linearized_receipts_for_rooms", f
247254 )
248255
268275 }
269276 return results
270277
271 def get_users_sent_receipts_between(self, last_id: int, current_id: int):
278 async def get_users_sent_receipts_between(
279 self, last_id: int, current_id: int
280 ) -> List[str]:
272281 """Get all users who sent receipts between `last_id` exclusive and
273282 `current_id` inclusive.
274283
275284 Returns:
276 Deferred[List[str]]
285 The list of users.
277286 """
278287
279288 if last_id == current_id:
288297
289298 return [r[0] for r in txn]
290299
291 return self.db_pool.runInteraction(
300 return await self.db_pool.runInteraction(
292301 "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
293302 )
294303
345354 )
346355
347356 def _invalidate_get_users_with_receipts_in_room(
348 self, room_id, receipt_type, user_id
357 self, room_id: str, receipt_type: str, user_id: str
349358 ):
350359 if receipt_type != "m.read":
351360 return
471480
472481 return rx_ts
473482
474 @defer.inlineCallbacks
475 def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
483 async def insert_receipt(
484 self,
485 room_id: str,
486 receipt_type: str,
487 user_id: str,
488 event_ids: List[str],
489 data: dict,
490 ) -> Optional[Tuple[int, int]]:
476491 """Insert a receipt, either from local client or remote server.
477492
478493 Automatically does conversion between linearized and graph
479494 representations.
480495 """
481496 if not event_ids:
482 return
497 return None
483498
484499 if len(event_ids) == 1:
485500 linearized_event_id = event_ids[0]
506521 else:
507522 raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
508523
509 linearized_event_id = yield self.db_pool.runInteraction(
524 linearized_event_id = await self.db_pool.runInteraction(
510525 "insert_receipt_conv", graph_to_linear
511526 )
512527
513 stream_id_manager = self._receipts_id_gen.get_next()
514 with stream_id_manager as stream_id:
515 event_ts = yield self.db_pool.runInteraction(
528 with await self._receipts_id_gen.get_next() as stream_id:
529 event_ts = await self.db_pool.runInteraction(
516530 "insert_linearized_receipt",
517531 self.insert_linearized_receipt_txn,
518532 room_id,
534548 now - event_ts,
535549 )
536550
537 yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
551 await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
538552
539553 max_persisted_id = self._receipts_id_gen.get_current_token()
540554
541555 return stream_id, max_persisted_id
542556
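insert_receipt above also moves from "stream_id_manager = self._receipts_id_gen.get_next(); with stream_id_manager as stream_id:" to "with await self._receipts_id_gen.get_next() as stream_id:", i.e. get_next() is now a coroutine that resolves to a context manager. A hedged sketch of that shape with a toy generator, not Synapse's real id generator:

import asyncio
from contextlib import contextmanager


class ToyIdGenerator:
    def __init__(self) -> None:
        self._current = 0

    async def get_next(self):
        # e.g. the real generator may need a DB round-trip to allocate an id.
        await asyncio.sleep(0)
        self._current += 1

        @contextmanager
        def manager():
            yield self._current

        return manager()


async def main() -> None:
    id_gen = ToyIdGenerator()
    with await id_gen.get_next() as stream_id:
        print("allocated stream_id", stream_id)


asyncio.run(main())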
543 def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
544 return self.db_pool.runInteraction(
557 async def insert_graph_receipt(
558 self, room_id, receipt_type, user_id, event_ids, data
559 ):
560 return await self.db_pool.runInteraction(
545561 "insert_graph_receipt",
546562 self.insert_graph_receipt_txn,
547563 room_id,
1616
1717 import logging
1818 import re
19 from typing import Dict, List, Optional
20
21 from twisted.internet.defer import Deferred
19 from typing import Any, Dict, List, Optional, Tuple
2220
2321 from synapse.api.constants import UserTypes
2422 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
4745 )
4846
4947 @cached()
50 def get_user_by_id(self, user_id):
51 return self.db_pool.simple_select_one(
48 async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
49 return await self.db_pool.simple_select_one(
5250 table="users",
5351 keyvalues={"name": user_id},
5452 retcols=[
8583 return is_trial
8684
8785 @cached()
88 def get_user_by_access_token(self, token):
86 async def get_user_by_access_token(self, token: str) -> Optional[dict]:
8987 """Get a user from the given access token.
9088
9189 Args:
92 token (str): The access token of a user.
93 Returns:
94 defer.Deferred: None, if the token did not match, otherwise dict
95 including the keys `name`, `is_guest`, `device_id`, `token_id`,
96 `valid_until_ms`.
97 """
98 return self.db_pool.runInteraction(
90 token: The access token of a user.
91 Returns:
92 None, if the token did not match, otherwise dict
93 including the keys `name`, `is_guest`, `device_id`, `token_id`,
94 `valid_until_ms`.
95 """
96 return await self.db_pool.runInteraction(
9997 "get_user_by_access_token", self._query_for_auth, token
10098 )
10199
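@cached() now wraps coroutine methods such as get_user_by_access_token. A very rough sketch of caching an async method follows; Synapse's real descriptor also handles invalidation and replication, which this toy ignores:

import asyncio
import functools


def toy_cached(func):
    cache = {}

    @functools.wraps(func)
    async def wrapper(self, key):
        if key not in cache:
            cache[key] = await func(self, key)
        return cache[key]

    return wrapper


class ToyStore:
    def __init__(self) -> None:
        self.db_hits = 0

    @toy_cached
    async def get_user_by_access_token(self, token: str) -> dict:
        self.db_hits += 1  # pretend this is the runInteraction call
        return {"name": "@alice:example.org", "token_id": 1}


async def main() -> None:
    store = ToyStore()
    await store.get_user_by_access_token("some_token")
    await store.get_user_by_access_token("some_token")
    print("db hits:", store.db_hits)  # 1: the second call is served from the cache


asyncio.run(main())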
102100 @cached()
103 async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
101 async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
104102 """Get the expiration timestamp for the account bearing a given user ID.
105103
106104 Args:
282280
283281 return bool(res) if res else False
284282
285 def set_server_admin(self, user, admin):
283 async def set_server_admin(self, user: UserID, admin: bool) -> None:
286284 """Sets whether a user is an admin of this homeserver.
287285
288286 Args:
289 user (UserID): user ID of the user to test
290 admin (bool): true iff the user is to be a server admin,
291 false otherwise.
287 user: user ID of the user to modify
288 admin: true iff the user is to be a server admin, false otherwise.
292289 """
293290
294291 def set_server_admin_txn(txn):
299296 txn, self.get_user_by_id, (user.to_string(),)
300297 )
301298
302 return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
299 await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
303300
304301 def _query_for_auth(self, txn, token):
305302 sql = (
306 "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
303 "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
307304 " access_tokens.device_id, access_tokens.valid_until_ms"
308305 " FROM users"
309306 " INNER JOIN access_tokens on users.name = access_tokens.user_id"
365362 )
366363 return True if res == UserTypes.SUPPORT else False
367364
368 def get_users_by_id_case_insensitive(self, user_id):
365 async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
369366 """Gets users that match user_id case insensitively.
370 Returns a mapping of user_id -> password_hash.
367
368 Returns:
369 A mapping of user_id -> password_hash.
371370 """
372371
373372 def f(txn):
375374 txn.execute(sql, (user_id,))
376375 return dict(txn)
377376
378 return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
377 return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
379378
380379 async def get_user_by_external_id(
381380 self, auth_provider: str, external_id: str
409408
410409 return await self.db_pool.runInteraction("count_users", _count_users)
411410
412 def count_daily_user_type(self):
411 async def count_daily_user_type(self) -> Dict[str, int]:
413412 """
414413 Counts 1) native non-guest users
415414 2) native guest users
438437 results[row[0]] = row[1]
439438 return results
440439
441 return self.db_pool.runInteraction(
440 return await self.db_pool.runInteraction(
442441 "count_daily_user_type", _count_daily_user_type
443442 )
444443
530529 "user_get_threepids",
531530 )
532531
533 def user_delete_threepid(self, user_id, medium, address):
534 return self.db_pool.simple_delete(
532 async def user_delete_threepid(self, user_id, medium, address) -> None:
533 await self.db_pool.simple_delete(
535534 "user_threepids",
536535 keyvalues={"user_id": user_id, "medium": medium, "address": address},
537536 desc="user_delete_threepid",
538537 )
539538
540 def user_delete_threepids(self, user_id: str):
539 async def user_delete_threepids(self, user_id: str) -> None:
541540 """Delete all threepid this user has bound
542541
543542 Args:
544543 user_id: The user id to delete all threepids of
545544
546545 """
547 return self.db_pool.simple_delete(
546 await self.db_pool.simple_delete(
548547 "user_threepids",
549548 keyvalues={"user_id": user_id},
550549 desc="user_delete_threepids",
551550 )
552551
553 def add_user_bound_threepid(self, user_id, medium, address, id_server):
552 async def add_user_bound_threepid(
553 self, user_id: str, medium: str, address: str, id_server: str
554 ):
554555 """The server proxied a bind request to the given identity server on
555556 behalf of the given user. We need to remember this in case the user
556557 asks us to unbind the threepid.
557558
558559 Args:
559 user_id (str)
560 medium (str)
561 address (str)
562 id_server (str)
563
564 Returns:
565 Deferred
560 user_id: The user who bound the threepid.
561 medium: The medium of the threepid (e.g. "email").
562 address: The address of the threepid.
563 id_server: The identity server the bind was proxied to.
566564 """
567565 # We need to use an upsert, in case the user had already bound the
568566 # threepid
569 return self.db_pool.simple_upsert(
567 await self.db_pool.simple_upsert(
570568 table="user_threepid_id_server",
571569 keyvalues={
572570 "user_id": user_id,
579577 desc="add_user_bound_threepid",
580578 )
581579
582 def user_get_bound_threepids(self, user_id):
580 async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
583581 """Get the threepids that a user has bound to an identity server through the homeserver
584582 The homeserver remembers where binds to an identity server occurred. Using this
585583 method can retrieve those threepids.
586584
587585 Args:
588 user_id (str): The ID of the user to retrieve threepids for
589
590 Returns:
591 Deferred[list[dict]]: List of dictionaries containing the following:
586 user_id: The ID of the user to retrieve threepids for
587
588 Returns:
589 List of dictionaries containing the following keys:
592590 medium (str): The medium of the threepid (e.g "email")
593591 address (str): The address of the threepid (e.g "bob@example.com")
594592 """
595 return self.db_pool.simple_select_list(
593 return await self.db_pool.simple_select_list(
596594 table="user_threepid_id_server",
597595 keyvalues={"user_id": user_id},
598596 retcols=["medium", "address"],
599597 desc="user_get_bound_threepids",
600598 )
601599
602 def remove_user_bound_threepid(self, user_id, medium, address, id_server):
600 async def remove_user_bound_threepid(
601 self, user_id: str, medium: str, address: str, id_server: str
602 ) -> None:
603603 """The server proxied an unbind request to the given identity server on
604604 behalf of the given user, so we remove the mapping of threepid to
605605 identity server.
606606
607607 Args:
608 user_id (str)
609 medium (str)
610 address (str)
611 id_server (str)
612
613 Returns:
614 Deferred
615 """
616 return self.db_pool.simple_delete(
608 user_id
609 medium
610 address
611 id_server
612 """
613 await self.db_pool.simple_delete(
617614 table="user_threepid_id_server",
618615 keyvalues={
619616 "user_id": user_id,
624621 desc="remove_user_bound_threepid",
625622 )
626623
627 def get_id_servers_user_bound(self, user_id, medium, address):
624 async def get_id_servers_user_bound(
625 self, user_id: str, medium: str, address: str
626 ) -> List[str]:
628627 """Get the list of identity servers that the server proxied bind
629628 requests to for given user and threepid
630629
631630 Args:
632 user_id (str)
633 medium (str)
634 address (str)
635
636 Returns:
637 Deferred[list[str]]: Resolves to a list of identity servers
638 """
639 return self.db_pool.simple_select_onecol(
631 user_id: The user to query for identity servers.
632 medium: The medium to query for identity servers.
633 address: The address to query for identity servers.
634
635 Returns:
636 A list of identity servers
637 """
638 return await self.db_pool.simple_select_onecol(
640639 table="user_threepid_id_server",
641640 keyvalues={"user_id": user_id, "medium": medium, "address": address},
642641 retcol="id_server",
664663 # Convert the integer into a boolean.
665664 return res == 1
666665
667 def get_threepid_validation_session(
668 self, medium, client_secret, address=None, sid=None, validated=True
669 ):
666 async def get_threepid_validation_session(
667 self,
668 medium: Optional[str],
669 client_secret: str,
670 address: Optional[str] = None,
671 sid: Optional[str] = None,
672 validated: Optional[bool] = True,
673 ) -> Optional[Dict[str, Any]]:
670674 """Gets a session_id and last_send_attempt (if available) for a
671675 combination of validation metadata
672676
673677 Args:
674 medium (str|None): The medium of the 3PID
675 address (str|None): The address of the 3PID
676 sid (str|None): The ID of the validation session
677 client_secret (str): A unique string provided by the client to help identify this
678 medium: The medium of the 3PID
679 client_secret: A unique string provided by the client to help identify this
678680 validation attempt
679 validated (bool|None): Whether sessions should be filtered by
681 address: The address of the 3PID
682 sid: The ID of the validation session
683 validated: Whether sessions should be filtered by
680684 whether they have been validated already or not. None to
681685 perform no filtering
682686
683687 Returns:
684 Deferred[dict|None]: A dict containing the following:
688 A dict containing the following:
685689 * address - address of the 3pid
686690 * medium - medium of the 3pid
687691 * client_secret - a secret provided by the client for this validation session
727731
728732 return rows[0]
729733
730 return self.db_pool.runInteraction(
734 return await self.db_pool.runInteraction(
731735 "get_threepid_validation_session", get_threepid_validation_session_txn
732736 )
733737
734 def delete_threepid_session(self, session_id):
738 async def delete_threepid_session(self, session_id: str) -> None:
735739 """Removes a threepid validation session from the database. This can
736740 be done after validation has been performed and whatever action was
737741 waiting on it has been carried out
738742
739743 Args:
740 session_id (str): The ID of the session to delete
744 session_id: The ID of the session to delete
741745 """
742746
743747 def delete_threepid_session_txn(txn):
752756 keyvalues={"session_id": session_id},
753757 )
754758
755 return self.db_pool.runInteraction(
759 await self.db_pool.runInteraction(
756760 "delete_threepid_session", delete_threepid_session_txn
757761 )
758762
890894 super(RegistrationStore, self).__init__(database, db_conn, hs)
891895
892896 self._account_validity = hs.config.account_validity
897 self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
893898
894899 if self._account_validity.enabled:
895900 self._clock.call_later(
941946 desc="add_access_token_to_user",
942947 )
943948
944 def register_user(
949 async def register_user(
945950 self,
946 user_id,
947 password_hash=None,
948 was_guest=False,
949 make_guest=False,
950 appservice_id=None,
951 create_profile_with_displayname=None,
952 admin=False,
953 user_type=None,
954 ):
951 user_id: str,
952 password_hash: Optional[str] = None,
953 was_guest: bool = False,
954 make_guest: bool = False,
955 appservice_id: Optional[str] = None,
956 create_profile_with_displayname: Optional[str] = None,
957 admin: bool = False,
958 user_type: Optional[str] = None,
959 shadow_banned: bool = False,
960 ) -> None:
955961 """Attempts to register an account.
956962
957963 Args:
958 user_id (str): The desired user ID to register.
959 password_hash (str|None): Optional. The password hash for this user.
960 was_guest (bool): Optional. Whether this is a guest account being
961 upgraded to a non-guest account.
962 make_guest (boolean): True if the the new user should be guest,
963 false to add a regular user account.
964 appservice_id (str): The ID of the appservice registering the user.
965 create_profile_with_displayname (unicode): Optionally create a profile for
964 user_id: The desired user ID to register.
965 password_hash: Optional. The password hash for this user.
966 was_guest: Whether this is a guest account being upgraded to a
967 non-guest account.
968 make_guest: True if the new user should be guest, false to add a
969 regular user account.
970 appservice_id: The ID of the appservice registering the user.
971 create_profile_with_displayname: Optionally create a profile for
966972 the user, setting their displayname to the given value
967 admin (boolean): is an admin user?
968 user_type (str|None): type of user. One of the values from
969 api.constants.UserTypes, or None for a normal user.
973 admin: is an admin user?
974 user_type: type of user. One of the values from api.constants.UserTypes,
975 or None for a normal user.
976 shadow_banned: Whether the user is shadow-banned, i.e. they may be
977 told their requests succeeded but we ignore them.
970978
971979 Raises:
972980 StoreError if the user_id could not be registered.
973
974 Returns:
975 Deferred
976 """
977 return self.db_pool.runInteraction(
981 """
982 await self.db_pool.runInteraction(
978983 "register_user",
979984 self._register_user,
980985 user_id,
985990 create_profile_with_displayname,
986991 admin,
987992 user_type,
993 shadow_banned,
988994 )
989995
990996 def _register_user(
9981004 create_profile_with_displayname,
9991005 admin,
10001006 user_type,
1007 shadow_banned,
10011008 ):
10021009 user_id_obj = UserID.from_string(user_id)
10031010
10271034 "appservice_id": appservice_id,
10281035 "admin": 1 if admin else 0,
10291036 "user_type": user_type,
1037 "shadow_banned": shadow_banned,
10301038 },
10311039 )
10321040 else:
10411049 "appservice_id": appservice_id,
10421050 "admin": 1 if admin else 0,
10431051 "user_type": user_type,
1052 "shadow_banned": shadow_banned,
10441053 },
10451054 )
10461055
10741083
10751084 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
10761085
1077 def record_user_external_id(
1086 async def record_user_external_id(
10781087 self, auth_provider: str, external_id: str, user_id: str
1079 ) -> Deferred:
1088 ) -> None:
10801089 """Record a mapping from an external user id to a mxid
10811090
10821091 Args:
10841093 external_id: id on that system
10851094 user_id: complete mxid that it is mapped to
10861095 """
1087 return self.db_pool.simple_insert(
1096 await self.db_pool.simple_insert(
10881097 table="user_external_ids",
10891098 values={
10901099 "auth_provider": auth_provider,
10941103 desc="record_user_external_id",
10951104 )
10961105
1097 def user_set_password_hash(self, user_id, password_hash):
1106 async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
10981107 """
10991108 NB. This does *not* evict any cache because the one use for this
11001109 removes most of the entries subsequently anyway so it would be
11071116 )
11081117 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
11091118
1110 return self.db_pool.runInteraction(
1119 await self.db_pool.runInteraction(
11111120 "user_set_password_hash", user_set_password_hash_txn
11121121 )
11131122
1114 def user_set_consent_version(self, user_id, consent_version):
1123 async def user_set_consent_version(
1124 self, user_id: str, consent_version: str
1125 ) -> None:
11151126 """Updates the user table to record privacy policy consent
11161127
11171128 Args:
1118 user_id (str): full mxid of the user to update
1119 consent_version (str): version of the policy the user has consented
1120 to
1129 user_id: full mxid of the user to update
1130 consent_version: version of the policy the user has consented to
11211131
11221132 Raises:
11231133 StoreError(404) if user not found
11321142 )
11331143 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
11341144
1135 return self.db_pool.runInteraction("user_set_consent_version", f)
1136
1137 def user_set_consent_server_notice_sent(self, user_id, consent_version):
1145 await self.db_pool.runInteraction("user_set_consent_version", f)
1146
1147 async def user_set_consent_server_notice_sent(
1148 self, user_id: str, consent_version: str
1149 ) -> None:
11381150 """Updates the user table to record that we have sent the user a server
11391151 notice about privacy policy consent
11401152
11411153 Args:
1142 user_id (str): full mxid of the user to update
1143 consent_version (str): version of the policy we have notified the
1144 user about
1154 user_id: full mxid of the user to update
1155 consent_version: version of the policy we have notified the user about
11451156
11461157 Raises:
11471158 StoreError(404) if user not found
11561167 )
11571168 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
11581169
1159 return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
1160
1161 def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
1170 await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
1171
1172 async def user_delete_access_tokens(
1173 self,
1174 user_id: str,
1175 except_token_id: Optional[str] = None,
1176 device_id: Optional[str] = None,
1177 ) -> List[Tuple[str, int, Optional[str]]]:
11621178 """
11631179 Invalidate access tokens belonging to a user
11641180
11651181 Args:
1166 user_id (str): ID of user the tokens belong to
1167 except_token_id (str): list of access_tokens IDs which should
1168 *not* be deleted
1169 device_id (str|None): ID of device the tokens are associated with.
1182 user_id: ID of user the tokens belong to
1183 except_token_id: access_tokens ID which should *not* be deleted
1184 device_id: ID of device the tokens are associated with.
11701185 If None, tokens associated with any device (or no device) will
11711186 be deleted
11721187 Returns:
1173 defer.Deferred[list[str, int, str|None, int]]: a list of
1174 (token, token id, device id) for each of the deleted tokens
1188 A list of (token, token id, device id) tuples, one for each of the deleted tokens
11751189 """
11761190
11771191 def f(txn):
12021216
12031217 return tokens_and_devices
12041218
1205 return self.db_pool.runInteraction("user_delete_access_tokens", f)
1206
1207 def delete_access_token(self, access_token):
1219 return await self.db_pool.runInteraction("user_delete_access_tokens", f)
1220
1221 async def delete_access_token(self, access_token: str) -> None:
12081222 def f(txn):
12091223 self.db_pool.simple_delete_one_txn(
12101224 txn, table="access_tokens", keyvalues={"token": access_token}
12141228 txn, self.get_user_by_access_token, (access_token,)
12151229 )
12161230
1217 return self.db_pool.runInteraction("delete_access_token", f)
1231 await self.db_pool.runInteraction("delete_access_token", f)
12181232
12191233 @cached()
12201234 async def is_guest(self, user_id: str) -> bool:
12281242
12291243 return res if res else False
12301244
1231 def add_user_pending_deactivation(self, user_id):
1245 async def add_user_pending_deactivation(self, user_id: str) -> None:
12321246 """
12331247 Adds a user to the table of users who need to be parted from all the rooms they're
12341248 in
12351249 """
1236 return self.db_pool.simple_insert(
1250 await self.db_pool.simple_insert(
12371251 "users_pending_deactivation",
12381252 values={"user_id": user_id},
12391253 desc="add_user_pending_deactivation",
12401254 )
12411255
1242 def del_user_pending_deactivation(self, user_id):
1256 async def del_user_pending_deactivation(self, user_id: str) -> None:
12431257 """
12441258 Removes the given user from the table of users who need to be parted from all the
12451259 rooms they're in, effectively marking that user as fully deactivated.
12461260 """
12471261 # XXX: This should be simple_delete_one but we failed to put a unique index on
12481262 # the table, so somehow duplicate entries have ended up in it.
1249 return self.db_pool.simple_delete(
1263 await self.db_pool.simple_delete(
12501264 "users_pending_deactivation",
12511265 keyvalues={"user_id": user_id},
12521266 desc="del_user_pending_deactivation",
12531267 )
12541268
1255 def get_user_pending_deactivation(self):
1269 async def get_user_pending_deactivation(self) -> Optional[str]:
12561270 """
12571271 Gets one user from the table of users waiting to be parted from all the rooms
12581272 they're in.
12591273 """
1260 return self.db_pool.simple_select_one_onecol(
1274 return await self.db_pool.simple_select_one_onecol(
12611275 "users_pending_deactivation",
12621276 keyvalues={},
12631277 retcol="user_id",
12651279 desc="get_users_pending_deactivation",
12661280 )
12671281
1268 def validate_threepid_session(self, session_id, client_secret, token, current_ts):
1282 async def validate_threepid_session(
1283 self, session_id: str, client_secret: str, token: str, current_ts: int
1284 ) -> Optional[str]:
12691285 """Attempt to validate a threepid session using a token
12701286
12711287 Args:
1272 session_id (str): The id of a validation session
1273 client_secret (str): A unique string provided by the client to
1274 help identify this validation attempt
1275 token (str): A validation token
1276 current_ts (int): The current unix time in milliseconds. Used for
1277 checking token expiry status
1288 session_id: The id of a validation session
1289 client_secret: A unique string provided by the client to help identify
1290 this validation attempt
1291 token: A validation token
1292 current_ts: The current unix time in milliseconds. Used for checking
1293 token expiry status
12781294
12791295 Raises:
12801296 ThreepidValidationError: if a matching validation token was not found or has
12811297 expired
12821298
12831299 Returns:
1284 deferred str|None: A str representing a link to redirect the user
1285 to if there is one.
1300 A str representing a link to redirect the user to if there is one.
12861301 """
12871302
12881303 # Insert everything into a transaction in order to run atomically
12961311 )
12971312
12981313 if not row:
1299 raise ThreepidValidationError(400, "Unknown session_id")
1314 if self._ignore_unknown_session_error:
1315 # If we need to inhibit the error caused by an incorrect session ID,
1316 # use None as placeholder values for the client secret and the
1317 # validation timestamp.
1318 # It shouldn't be an issue because they're both only checked after
1319 # the token check, which should fail. And if it doesn't for some
1320 # reason, the next check is on the client secret, which is NOT NULL,
1321 # so we don't have to worry about the client secret matching by
1322 # accident.
1323 row = {"client_secret": None, "validated_at": None}
1324 else:
1325 raise ThreepidValidationError(400, "Unknown session_id")
1326
13001327 retrieved_client_secret = row["client_secret"]
13011328 validated_at = row["validated_at"]
1302
1303 if retrieved_client_secret != client_secret:
1304 raise ThreepidValidationError(
1305 400, "This client_secret does not match the provided session_id"
1306 )
13071329
13081330 row = self.db_pool.simple_select_one_txn(
13091331 txn,
13201342 expires = row["expires"]
13211343 next_link = row["next_link"]
13221344
1345 if retrieved_client_secret != client_secret:
1346 raise ThreepidValidationError(
1347 400, "This client_secret does not match the provided session_id"
1348 )
1349
13231350 # If the session is already validated, no need to revalidate
13241351 if validated_at:
13251352 return next_link
13401367 return next_link
13411368
13421369 # Return next_link if it exists
1343 return self.db_pool.runInteraction(
1370 return await self.db_pool.runInteraction(
13441371 "validate_threepid_session_txn", validate_threepid_session_txn
13451372 )
13461373
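The validate_threepid_session change above substitutes placeholder values for an unknown session when request_token_inhibit_3pid_errors is set, so validation still fails later on the token check without revealing whether the session exists. A standalone toy of that control flow, not the real storage method:

class ThreepidValidationError(Exception):
    pass


SESSIONS = {"abc": {"client_secret": "s3cret", "validated_at": None}}


def lookup_session(session_id: str, ignore_unknown_session_error: bool) -> dict:
    row = SESSIONS.get(session_id)
    if not row:
        if ignore_unknown_session_error:
            # Placeholders: the later token check fails anyway, and
            # client_secret is NOT NULL in the table, so None never matches.
            return {"client_secret": None, "validated_at": None}
        raise ThreepidValidationError(400, "Unknown session_id")
    return row


print(lookup_session("missing", ignore_unknown_session_error=True))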
1347 def upsert_threepid_validation_session(
1374 async def start_or_continue_validation_session(
13481375 self,
1349 medium,
1350 address,
1351 client_secret,
1352 send_attempt,
1353 session_id,
1354 validated_at=None,
1355 ):
1356 """Upsert a threepid validation session
1357 Args:
1358 medium (str): The medium of the 3PID
1359 address (str): The address of the 3PID
1360 client_secret (str): A unique string provided by the client to
1361 help identify this validation attempt
1362 send_attempt (int): The latest send_attempt on this session
1363 session_id (str): The id of this validation session
1364 validated_at (int|None): The unix timestamp in milliseconds of
1365 when the session was marked as valid
1366 """
1367 insertion_values = {
1368 "medium": medium,
1369 "address": address,
1370 "client_secret": client_secret,
1371 }
1372
1373 if validated_at:
1374 insertion_values["validated_at"] = validated_at
1375
1376 return self.db_pool.simple_upsert(
1377 table="threepid_validation_session",
1378 keyvalues={"session_id": session_id},
1379 values={"last_send_attempt": send_attempt},
1380 insertion_values=insertion_values,
1381 desc="upsert_threepid_validation_session",
1382 )
1383
1384 def start_or_continue_validation_session(
1385 self,
1386 medium,
1387 address,
1388 session_id,
1389 client_secret,
1390 send_attempt,
1391 next_link,
1392 token,
1393 token_expires,
1394 ):
1376 medium: str,
1377 address: str,
1378 session_id: str,
1379 client_secret: str,
1380 send_attempt: int,
1381 next_link: Optional[str],
1382 token: str,
1383 token_expires: int,
1384 ) -> None:
13951385 """Creates a new threepid validation session if it does not already
13961386 exist and associates a new validation token with it
13971387
13981388 Args:
1399 medium (str): The medium of the 3PID
1400 address (str): The address of the 3PID
1401 session_id (str): The id of this validation session
1402 client_secret (str): A unique string provided by the client to
1403 help identify this validation attempt
1404 send_attempt (int): The latest send_attempt on this session
1405 next_link (str|None): The link to redirect the user to upon
1406 successful validation
1407 token (str): The validation token
1408 token_expires (int): The timestamp for which after the token
1409 will no longer be valid
1389 medium: The medium of the 3PID
1390 address: The address of the 3PID
1391 session_id: The id of this validation session
1392 client_secret: A unique string provided by the client to help
1393 identify this validation attempt
1394 send_attempt: The latest send_attempt on this session
1395 next_link: The link to redirect the user to upon successful validation
1396 token: The validation token
1397 token_expires: The timestamp after which the token will no
1398 longer be valid
14101399 """
14111400
14121401 def start_or_continue_validation_session_txn(txn):
14351424 },
14361425 )
14371426
1438 return self.db_pool.runInteraction(
1427 await self.db_pool.runInteraction(
14391428 "start_or_continue_validation_session",
14401429 start_or_continue_validation_session_txn,
14411430 )
14421431
1443 def cull_expired_threepid_validation_tokens(self):
1432 async def cull_expired_threepid_validation_tokens(self) -> None:
14441433 """Remove threepid validation tokens with expiry dates that have passed"""
14451434
14461435 def cull_expired_threepid_validation_tokens_txn(txn, ts):
14481437 DELETE FROM threepid_validation_token WHERE
14491438 expires < ?
14501439 """
1451 return txn.execute(sql, (ts,))
1452
1453 return self.db_pool.runInteraction(
1440 txn.execute(sql, (ts,))
1441
1442 await self.db_pool.runInteraction(
14541443 "cull_expired_threepid_validation_tokens",
14551444 cull_expired_threepid_validation_tokens_txn,
14561445 self.clock.time_msec(),
1313 # limitations under the License.
1414
1515 import logging
16 from typing import Optional
1617
1718 from synapse.storage._base import SQLBaseStore
1819
2021
2122
2223 class RejectionsStore(SQLBaseStore):
23 def get_rejection_reason(self, event_id):
24 return self.db_pool.simple_select_one_onecol(
24 async def get_rejection_reason(self, event_id: str) -> Optional[str]:
25 return await self.db_pool.simple_select_one_onecol(
2526 table="rejections",
2627 retcol="reason",
2728 keyvalues={"event_id": event_id},
3333
3434 class RelationsWorkerStore(SQLBaseStore):
3535 @cached(tree=True)
36 def get_relations_for_event(
36 async def get_relations_for_event(
3737 self,
38 event_id,
39 relation_type=None,
40 event_type=None,
41 aggregation_key=None,
42 limit=5,
43 direction="b",
44 from_token=None,
45 to_token=None,
46 ):
38 event_id: str,
39 relation_type: Optional[str] = None,
40 event_type: Optional[str] = None,
41 aggregation_key: Optional[str] = None,
42 limit: int = 5,
43 direction: str = "b",
44 from_token: Optional[RelationPaginationToken] = None,
45 to_token: Optional[RelationPaginationToken] = None,
46 ) -> PaginationChunk:
4747 """Get a list of relations for an event, ordered by topological ordering.
4848
4949 Args:
50 event_id (str): Fetch events that relate to this event ID.
51 relation_type (str|None): Only fetch events with this relation
52 type, if given.
53 event_type (str|None): Only fetch events with this event type, if
54 given.
55 aggregation_key (str|None): Only fetch events with this aggregation
56 key, if given.
57 limit (int): Only fetch the most recent `limit` events.
58 direction (str): Whether to fetch the most recent first (`"b"`) or
59 the oldest first (`"f"`).
60 from_token (RelationPaginationToken|None): Fetch rows from the given
61 token, or from the start if None.
62 to_token (RelationPaginationToken|None): Fetch rows up to the given
63 token, or up to the end if None.
50 event_id: Fetch events that relate to this event ID.
51 relation_type: Only fetch events with this relation type, if given.
52 event_type: Only fetch events with this event type, if given.
53 aggregation_key: Only fetch events with this aggregation key, if given.
54 limit: Only fetch the most recent `limit` events.
55 direction: Whether to fetch the most recent first (`"b"`) or the
56 oldest first (`"f"`).
57 from_token: Fetch rows from the given token, or from the start if None.
58 to_token: Fetch rows up to the given token, or up to the end if None.
6459
6560 Returns:
66 Deferred[PaginationChunk]: List of event IDs that match relations
67 requested. The rows are of the form `{"event_id": "..."}`.
61 List of event IDs that match relations requested. The rows are of
62 the form `{"event_id": "..."}`.
6863 """
6964
7065 where_clause = ["relates_to_id = ?"]
130125 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
131126 )
132127
133 return self.db_pool.runInteraction(
128 return await self.db_pool.runInteraction(
134129 "get_recent_references_for_event", _get_recent_references_for_event_txn
135130 )
136131
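A toy illustration of the pagination parameters now annotated on get_relations_for_event: direction "b" yields the most recent events first, "f" the oldest first, and limit caps the chunk. This is a pure stand-in, not PaginationChunk:

from typing import List


def paginate(event_ids: List[str], limit: int = 5, direction: str = "b") -> List[str]:
    # Assume event_ids is already in topological (oldest-first) order.
    ordered = list(reversed(event_ids)) if direction == "b" else list(event_ids)
    return ordered[:limit]


events = ["$e1", "$e2", "$e3", "$e4", "$e5", "$e6"]
print(paginate(events, limit=2, direction="b"))  # ['$e6', '$e5']
print(paginate(events, limit=2, direction="f"))  # ['$e1', '$e2']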
137132 @cached(tree=True)
138 def get_aggregation_groups_for_event(
133 async def get_aggregation_groups_for_event(
139134 self,
140 event_id,
141 event_type=None,
142 limit=5,
143 direction="b",
144 from_token=None,
145 to_token=None,
146 ):
135 event_id: str,
136 event_type: Optional[str] = None,
137 limit: int = 5,
138 direction: str = "b",
139 from_token: Optional[AggregationPaginationToken] = None,
140 to_token: Optional[AggregationPaginationToken] = None,
141 ) -> PaginationChunk:
147142 """Get a list of annotations on the event, grouped by event type and
148143 aggregation key, sorted by count.
149144
151146 on an event.
152147
153148 Args:
154 event_id (str): Fetch events that relate to this event ID.
155 event_type (str|None): Only fetch events with this event type, if
156 given.
157 limit (int): Only fetch the `limit` groups.
158 direction (str): Whether to fetch the highest count first (`"b"`) or
149 event_id: Fetch events that relate to this event ID.
150 event_type: Only fetch events with this event type, if given.
151 limit: Only fetch the `limit` groups.
152 direction: Whether to fetch the highest count first (`"b"`) or
159153 the lowest count first (`"f"`).
160 from_token (AggregationPaginationToken|None): Fetch rows from the
161 given token, or from the start if None.
162 to_token (AggregationPaginationToken|None): Fetch rows up to the
163 given token, or up to the end if None.
164
154 from_token: Fetch rows from the given token, or from the start if None.
155 to_token: Fetch rows up to the given token, or up to the end if None.
165156
166157 Returns:
167 Deferred[PaginationChunk]: List of groups of annotations that
168 match. Each row is a dict with `type`, `key` and `count` fields.
158 List of groups of annotations that match. Each row is a dict with
159 `type`, `key` and `count` fields.
169160 """
170161
171162 where_clause = ["relates_to_id = ?", "relation_type = ?"]
224215 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
225216 )
226217
227 return self.db_pool.runInteraction(
218 return await self.db_pool.runInteraction(
228219 "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
229220 )
230221
278269
279270 return await self.get_event(edit_id, allow_none=True)
280271
281 def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
272 async def has_user_annotated_event(
273 self, parent_id: str, event_type: str, aggregation_key: str, sender: str
274 ) -> bool:
282275 """Check if a user has already annotated an event with the same key
283276 (e.g. already liked an event).
284277
285278 Args:
286 parent_id (str): The event being annotated
287 event_type (str): The event type of the annotation
288 aggregation_key (str): The aggregation key of the annotation
289 sender (str): The sender of the annotation
279 parent_id: The event being annotated
280 event_type: The event type of the annotation
281 aggregation_key: The aggregation key of the annotation
282 sender: The sender of the annotation
290283
291284 Returns:
292 Deferred[bool]
285 True if the event is already annotated.
293286 """
294287
295288 sql = """
318311
319312 return bool(txn.fetchone())
320313
321 return self.db_pool.runInteraction(
314 return await self.db_pool.runInteraction(
322315 "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
323316 )
324317
2020 from enum import Enum
2121 from typing import Any, Dict, List, Optional, Tuple
2222
23 from canonicaljson import json
24
2523 from synapse.api.constants import EventTypes
2624 from synapse.api.errors import StoreError
2725 from synapse.api.room_versions import RoomVersion, RoomVersions
2826 from synapse.storage._base import SQLBaseStore, db_to_json
2927 from synapse.storage.database import DatabasePool, LoggingTransaction
3028 from synapse.storage.databases.main.search import SearchStore
31 from synapse.types import ThirdPartyInstanceID
29 from synapse.types import JsonDict, ThirdPartyInstanceID
30 from synapse.util import json_encoder
3231 from synapse.util.caches.descriptors import cached
3332
3433 logger = logging.getLogger(__name__)
3534
36
37 OpsLevel = collections.namedtuple(
38 "OpsLevel", ("ban_level", "kick_level", "redact_level")
39 )
4035
4136 RatelimitOverride = collections.namedtuple(
4237 "RatelimitOverride", ("messages_per_second", "burst_count")
7772
7873 self.config = hs.config
7974
80 def get_room(self, room_id):
75 async def get_room(self, room_id: str) -> dict:
8176 """Retrieve a room.
8277
8378 Args:
84 room_id (str): The ID of the room to retrieve.
79 room_id: The ID of the room to retrieve.
8580 Returns:
8681 A dict containing the room information, or None if the room is unknown.
8782 """
88 return self.db_pool.simple_select_one(
83 return await self.db_pool.simple_select_one(
8984 table="rooms",
9085 keyvalues={"room_id": room_id},
9186 retcols=("room_id", "is_public", "creator"),
9388 allow_none=True,
9489 )
9590
96 def get_room_with_stats(self, room_id: str):
91 async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
9792 """Retrieve room with statistics.
9893
9994 Args:
125120 res["public"] = bool(res["public"])
126121 return res
127122
128 return self.db_pool.runInteraction(
123 return await self.db_pool.runInteraction(
129124 "get_room_with_stats", get_room_with_stats_txn, room_id
130125 )
131126
132 def get_public_room_ids(self):
133 return self.db_pool.simple_select_onecol(
127 async def get_public_room_ids(self) -> List[str]:
128 return await self.db_pool.simple_select_onecol(
134129 table="rooms",
135130 keyvalues={"is_public": True},
136131 retcol="room_id",
137132 desc="get_public_room_ids",
138133 )
139134
140 def count_public_rooms(self, network_tuple, ignore_non_federatable):
135 async def count_public_rooms(
136 self,
137 network_tuple: Optional[ThirdPartyInstanceID],
138 ignore_non_federatable: bool,
139 ) -> int:
141140 """Counts the number of public rooms as tracked in the room_stats_current
142141 and room_stats_state table.
143142
144143 Args:
145 network_tuple (ThirdPartyInstanceID|None)
146 ignore_non_federatable (bool): If true filters out non-federatable rooms
144 network_tuple
145 ignore_non_federatable: If true filters out non-federatable rooms
147146 """
148147
149148 def _count_public_rooms_txn(txn):
187186 txn.execute(sql, query_args)
188187 return txn.fetchone()[0]
189188
190 return self.db_pool.runInteraction(
189 return await self.db_pool.runInteraction(
191190 "count_public_rooms", _count_public_rooms_txn
192191 )
193192
334333 return ret_val
335334
336335 @cached(max_entries=10000)
337 def is_room_blocked(self, room_id):
338 return self.db_pool.simple_select_one_onecol(
336 async def is_room_blocked(self, room_id: str) -> Optional[bool]:
337 return await self.db_pool.simple_select_one_onecol(
339338 table="blocked_rooms",
340339 keyvalues={"room_id": room_id},
341340 retcol="1",
590589
591590 return row
592591
593 def get_media_mxcs_in_room(self, room_id):
592 async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
594593 """Retrieves all the local and remote media MXC URIs in a given room
595594
596595 Args:
597 room_id (str)
596 room_id
598597
599598 Returns:
600 The local and remote media as a lists of tuples where the key is
601 the hostname and the value is the media ID.
599 The local and remote media as lists of media IDs.
602600 """
603601
604602 def _get_media_mxcs_in_room_txn(txn):
614612
615613 return local_media_mxcs, remote_media_mxcs
616614
617 return self.db_pool.runInteraction(
615 return await self.db_pool.runInteraction(
618616 "get_media_ids_in_room", _get_media_mxcs_in_room_txn
619617 )
620618
621 def quarantine_media_ids_in_room(self, room_id, quarantined_by):
619 async def quarantine_media_ids_in_room(
620 self, room_id: str, quarantined_by: str
621 ) -> int:
622622 """For a room loops through all events with media and quarantines
623623 the associated media
624624 """
631631 txn, local_mxcs, remote_mxcs, quarantined_by
632632 )
633633
634 return self.db_pool.runInteraction(
634 return await self.db_pool.runInteraction(
635635 "quarantine_media_in_room", _quarantine_media_in_room_txn
636636 )
637637
694694
695695 return local_media_mxcs, remote_media_mxcs
696696
697 def quarantine_media_by_id(
697 async def quarantine_media_by_id(
698698 self, server_name: str, media_id: str, quarantined_by: str,
699 ):
699 ) -> int:
700700 """quarantines a single local or remote media id
701701
702702 Args:
715715 txn, local_mxcs, remote_mxcs, quarantined_by
716716 )
717717
718 return self.db_pool.runInteraction(
718 return await self.db_pool.runInteraction(
719719 "quarantine_media_by_user", _quarantine_media_by_id_txn
720720 )
721721
722 def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
722 async def quarantine_media_ids_by_user(
723 self, user_id: str, quarantined_by: str
724 ) -> int:
723725 """quarantines all local media associated with a single user
724726
725727 Args:
731733 local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
732734 return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
733735
734 return self.db_pool.runInteraction(
736 return await self.db_pool.runInteraction(
735737 "quarantine_media_by_user", _quarantine_media_by_user_txn
736738 )
737739
11331135 },
11341136 )
11351137
1136 with self._public_room_id_gen.get_next() as next_id:
1138 with await self._public_room_id_gen.get_next() as next_id:
11371139 await self.db_pool.runInteraction(
11381140 "store_room_txn", store_room_txn, next_id
11391141 )
12001202 },
12011203 )
12021204
1203 with self._public_room_id_gen.get_next() as next_id:
1205 with await self._public_room_id_gen.get_next() as next_id:
12041206 await self.db_pool.runInteraction(
12051207 "set_room_is_public", set_room_is_public_txn, next_id
12061208 )
12801282 },
12811283 )
12821284
1283 with self._public_room_id_gen.get_next() as next_id:
1285 with await self._public_room_id_gen.get_next() as next_id:
12841286 await self.db_pool.runInteraction(
12851287 "set_room_is_public_appservice",
12861288 set_room_is_public_appservice_txn,
12881290 )
12891291 self.hs.get_notifier().on_new_replication_data()
12901292
1291 def get_room_count(self):
1292 """Retrieve a list of all rooms
1293 async def get_room_count(self) -> int:
1294 """Retrieve the total number of rooms.
12931295 """
12941296
12951297 def f(txn):
12981300 row = txn.fetchone()
12991301 return row[0] or 0
13001302
1301 return self.db_pool.runInteraction("get_rooms", f)
1302
1303 def add_event_report(
1304 self, room_id, event_id, user_id, reason, content, received_ts
1305 ):
1303 return await self.db_pool.runInteraction("get_rooms", f)
1304
1305 async def add_event_report(
1306 self,
1307 room_id: str,
1308 event_id: str,
1309 user_id: str,
1310 reason: str,
1311 content: JsonDict,
1312 received_ts: int,
1313 ) -> None:
13061314 next_id = self._event_reports_id_gen.get_next()
1307 return self.db_pool.simple_insert(
1315 await self.db_pool.simple_insert(
13081316 table="event_reports",
13091317 values={
13101318 "id": next_id,
13131321 "event_id": event_id,
13141322 "user_id": user_id,
13151323 "reason": reason,
1316 "content": json.dumps(content),
1324 "content": json_encoder.encode(content),
13171325 },
13181326 desc="add_event_report",
13191327 )
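add_event_report now serialises content with the shared json_encoder instead of canonicaljson's json module. A sketch of the effect, assuming (as an approximation, not a guaranteed match for synapse.util.json_encoder) a compact JSONEncoder:

import json

json_encoder = json.JSONEncoder(separators=(",", ":"), allow_nan=False)

content = {"reason": "spam", "score": -100}
print(json_encoder.encode(content))  # {"reason":"spam","score":-100}
print(json.dumps(content))           # default separators add spaces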
1414 # limitations under the License.
1515
1616 import logging
17 from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
18
19 from twisted.internet import defer
17 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
2018
2119 from synapse.api.constants import EventTypes, Membership
2220 from synapse.events import EventBase
9189 lambda: self._known_servers_count,
9290 )
9391
94 @defer.inlineCallbacks
95 def _count_known_servers(self):
92 async def _count_known_servers(self):
9693 """
9794 Count the servers that this server knows about.
9895
120117 txn.execute(query)
121118 return list(txn)[0][0]
122119
123 count = yield self.db_pool.runInteraction("get_known_servers", _transact)
120 count = await self.db_pool.runInteraction("get_known_servers", _transact)
124121
125122 # We always know about ourselves, even if we have nothing in
126123 # room_memberships (for example, the server is new).
154151 )
155152
156153 @cached(max_entries=100000, iterable=True)
157 def get_users_in_room(self, room_id: str):
158 return self.db_pool.runInteraction(
154 async def get_users_in_room(self, room_id: str) -> List[str]:
155 return await self.db_pool.runInteraction(
159156 "get_users_in_room", self.get_users_in_room_txn, room_id
160157 )
161158
182179 return [r[0] for r in txn]
183180
184181 @cached(max_entries=100000)
185 def get_room_summary(self, room_id: str):
182 async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
186183 """ Get the details of a room roughly suitable for use by the room
187184 summary extension to /sync. Useful when lazy loading room members.
188185 Args:
189186 room_id: The room ID to query
190187 Returns:
191 Deferred[dict[str, MemberSummary]:
192 dict of membership states, pointing to a MemberSummary named tuple.
188 dict of membership states, pointing to a MemberSummary named tuple.
193189 """
194190
195191 def _get_room_summary_txn(txn):
263259
264260 return res
265261
266 return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
262 return await self.db_pool.runInteraction(
263 "get_room_summary", _get_room_summary_txn
264 )
267265
268266 @cached()
269 def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
267 async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
270268 """Get all the rooms the *local* user is invited to.
271269
272270 Args:
273271 user_id: The user ID.
274272
275273 Returns:
276 A awaitable list of RoomsForUser.
277 """
278
279 return self.get_rooms_for_local_user_where_membership_is(
274 A list of RoomsForUser.
275 """
276
277 return await self.get_rooms_for_local_user_where_membership_is(
280278 user_id, [Membership.INVITE]
281279 )
282280
299297 return None
300298
301299 async def get_rooms_for_local_user_where_membership_is(
302 self, user_id: str, membership_list: List[str]
303 ) -> Optional[List[RoomsForUser]]:
300 self, user_id: str, membership_list: Collection[str]
301 ) -> List[RoomsForUser]:
304302 """Get all the rooms for this *local* user where the membership for this user
305303 matches one in the membership list.
306304
315313 The RoomsForUser that the user matches the membership types.
316314 """
317315 if not membership_list:
318 return None
316 return []
319317
320318 rooms = await self.db_pool.runInteraction(
321319 "get_rooms_for_local_user_where_membership_is",
359357 return results
360358
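get_rooms_for_local_user_where_membership_is above now returns [] instead of None for an empty membership_list, so call sites can iterate without a None check. A toy function showing the difference, not the storage method:

from typing import Collection, List


def rooms_where_membership_is(membership_list: Collection[str]) -> List[str]:
    if not membership_list:
        return []  # previously None, which forced callers to special-case it
    return ["!room-for-%s:example.org" % m for m in membership_list]


for room in rooms_where_membership_is([]):
    print(room)  # safe: the loop simply does nothing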
361359 @cached(max_entries=500000, iterable=True)
362 def get_rooms_for_user_with_stream_ordering(self, user_id: str):
360 async def get_rooms_for_user_with_stream_ordering(
361 self, user_id: str
362 ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
363363 """Returns a set of room_ids the user is currently joined to.
364364
365365 If a remote user only returns rooms this server is currently
369369 user_id
370370
371371 Returns:
372 Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
373 the rooms the user is in currently, along with the stream ordering
374 of the most recent join for that user and room.
375 """
376 return self.db_pool.runInteraction(
372 Returns the rooms the user is in currently, along with the stream
373 ordering of the most recent join for that user and room.
374 """
375 return await self.db_pool.runInteraction(
377376 "get_rooms_for_user_with_stream_ordering",
378377 self._get_rooms_for_user_with_stream_ordering_txn,
379378 user_id,
380379 )
381380
382 def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
381 def _get_rooms_for_user_with_stream_ordering_txn(
382 self, txn, user_id: str
383 ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
383384 # We use `current_state_events` here and not `local_current_membership`
384385 # as a) this gets called with remote users and b) this only gets called
385386 # for rooms the server is participating in.
406407 """
407408
408409 txn.execute(sql, (user_id, Membership.JOIN))
409 results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
410
411 return results
410 return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
412411
413412 async def get_users_server_still_shares_room_with(
414413 self, user_ids: Collection[str]
588587 raise NotImplementedError()
589588
590589 @cachedList(
591 cached_method_name="_get_joined_profile_from_event_id",
592 list_name="event_ids",
593 inlineCallbacks=True,
590 cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
594591 )
595 def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
592 async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
596593 """For given set of member event_ids check if they point to a join
597594 event and if so return the associated user and profile info.
598595
600597 event_ids: The member event IDs to lookup
601598
602599 Returns:
603 Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
600 dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
604601 to `user_id` and ProfileInfo (or None if not join event).
605602 """
606603
607 rows = yield self.db_pool.simple_select_many_batch(
604 rows = await self.db_pool.simple_select_many_batch(
608605 table="room_memberships",
609606 column="event_id",
610607 iterable=event_ids,
715712 return count == 0
716713
717714 @cached()
718 def get_forgotten_rooms_for_user(self, user_id: str):
715 async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
719716 """Gets all rooms the user has forgotten.
720717
721718 Args:
722 user_id
719 user_id: The user ID to query the rooms of.
723720
724721 Returns:
725 Deferred[set[str]]
722 The forgotten rooms.
726723 """
727724
728725 def _get_forgotten_rooms_for_user_txn(txn):
748745 txn.execute(sql, (user_id,))
749746 return {row[0] for row in txn if row[1] == 0}
750747
751 return self.db_pool.runInteraction(
748 return await self.db_pool.runInteraction(
752749 "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
753750 )
754751
771768
772769 return set(room_ids)
773770
774 def get_membership_from_event_ids(
771 async def get_membership_from_event_ids(
775772 self, member_event_ids: Iterable[str]
776773 ) -> List[dict]:
777774 """Get user_id and membership of a set of event IDs.
778775 """
779776
780 return self.db_pool.simple_select_many_batch(
777 return await self.db_pool.simple_select_many_batch(
781778 table="room_memberships",
782779 column="event_id",
783780 iterable=member_event_ids,
977974 def __init__(self, database: DatabasePool, db_conn, hs):
978975 super(RoomMemberStore, self).__init__(database, db_conn, hs)
979976
980 def forget(self, user_id: str, room_id: str):
977 async def forget(self, user_id: str, room_id: str) -> None:
981978 """Indicate that user_id wishes to discard history for room_id."""
982979
983980 def f(txn):
998995 txn, self.get_forgotten_rooms_for_user, (user_id,)
999996 )
1000997
1001 return self.db_pool.runInteraction("forget_membership", f)
1002
1003
1004 class _JoinedHostsCache(object):
998 await self.db_pool.runInteraction("forget_membership", f)
999
1000
1001 class _JoinedHostsCache:
10051002 """Cache for joined hosts in a room that is optimised to handle updates
10061003 via state deltas.
10071004 """
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- A table of the IP address and user-agent used to complete each step of a
16 -- user-interactive authentication session.
17 CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
18 session_id TEXT NOT NULL,
19 ip TEXT NOT NULL,
20 user_agent TEXT NOT NULL,
21 UNIQUE (session_id, ip, user_agent),
22 FOREIGN KEY (session_id)
23 REFERENCES ui_auth_sessions (session_id)
24 );
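A quick sanity check of the ui_auth_sessions_ips schema above using Python's sqlite3; the parent ui_auth_sessions table is stubbed with just a primary key, so this is illustrative only:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE ui_auth_sessions (session_id TEXT PRIMARY KEY)")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
        session_id TEXT NOT NULL,
        ip TEXT NOT NULL,
        user_agent TEXT NOT NULL,
        UNIQUE (session_id, ip, user_agent),
        FOREIGN KEY (session_id) REFERENCES ui_auth_sessions (session_id)
    )
    """
)
conn.execute("INSERT INTO ui_auth_sessions VALUES ('sess1')")
row = ("sess1", "203.0.113.5", "curl/7.68.0")
conn.execute("INSERT INTO ui_auth_sessions_ips VALUES (?, ?, ?)", row)
try:
    # The UNIQUE constraint rejects a second identical (session, ip, UA) row.
    conn.execute("INSERT INTO ui_auth_sessions_ips VALUES (?, ?, ?)", row)
except sqlite3.IntegrityError as e:
    print("duplicate rejected:", e)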
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- A shadow-banned user may be told that their requests succeeded when they were
16 -- actually ignored.
17 ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
0 /* Copyright 2020 The Matrix.org Foundation C.I.C.
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- This table is no longer used.
16 DROP TABLE IF EXISTS presence_allow_inbound;
0 /* Copyright 2020 The Matrix.org Foundation C.I.C.
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- We're hijacking the push actions to store unread messages and unread counts (specified
16 -- in MSC2654) because doing otherwise would result in either performance issues or
17 -- reimplementing a consequent bit of the push actions.
18
19 -- Add columns to event_push_actions and event_push_actions_staging to track unread
20 -- messages and calculate unread counts.
21 ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT;
22 ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT;
23
24 -- Add column to event_push_summary
25 ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT;
1515 import logging
1616 import re
1717 from collections import namedtuple
18 from typing import List, Optional
18 from typing import List, Optional, Set
1919
2020 from synapse.api.errors import SynapseError
21 from synapse.events import EventBase
2122 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
2223 from synapse.storage.database import DatabasePool
2324 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
619620 "count": count,
620621 }
621622
622 def _find_highlights_in_postgres(self, search_query, events):
623 async def _find_highlights_in_postgres(
624 self, search_query: str, events: List[EventBase]
625 ) -> Set[str]:
623626 """Given a list of events and a search term, return a list of words
624627 that match from the content of the event.
625628
627630 highlight the matching parts.
628631
629632 Args:
630 search_query (str)
631 events (list): A list of events
633 search_query
634 events: A list of events
632635
633636 Returns:
634 deferred : A set of strings.
637 A set of strings.
635638 """
636639
637640 def f(txn):
684687
685688 return highlight_words
686689
687 return self.db_pool.runInteraction("_find_highlights", f)
690 return await self.db_pool.runInteraction("_find_highlights", f)
688691
689692
690693 def _to_postgres_options(options_dict):
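The change above follows the pattern applied throughout this diff: methods that returned a Twisted `Deferred` become `async` coroutines that `await` `runInteraction`. A minimal before/after sketch, with a hypothetical `do_query` transaction function:

```python
# Before: the caller gets a Deferred and must yield it or add callbacks.
def get_thing(self):
    return self.db_pool.runInteraction("get_thing", do_query)

# After: a coroutine; callers simply write `await store.get_thing()`.
async def get_thing(self):
    return await self.db_pool.runInteraction("get_thing", do_query)
```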
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from typing import Dict, Iterable, List, Tuple
16
1517 from unpaddedbase64 import encode_base64
1618
1719 from synapse.storage._base import SQLBaseStore
20 from synapse.storage.types import Cursor
1821 from synapse.util.caches.descriptors import cached, cachedList
1922
2023
2831 @cachedList(
2932 cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
3033 )
31 def get_event_reference_hashes(self, event_ids):
34 async def get_event_reference_hashes(
35 self, event_ids: Iterable[str]
36 ) -> Dict[str, Dict[str, bytes]]:
37 """Get all hashes for given events.
38
39 Args:
40 event_ids: The event IDs to get hashes for.
41
42 Returns:
43 A mapping of event ID to a mapping of algorithm to hash.
44 """
45
3246 def f(txn):
3347 return {
3448 event_id: self._get_event_reference_hashes_txn(txn, event_id)
3549 for event_id in event_ids
3650 }
3751
38 return self.db_pool.runInteraction("get_event_reference_hashes", f)
52 return await self.db_pool.runInteraction("get_event_reference_hashes", f)
3953
40 async def add_event_hashes(self, event_ids):
54 async def add_event_hashes(
55 self, event_ids: Iterable[str]
56 ) -> List[Tuple[str, Dict[str, str]]]:
57 """
58
59 Args:
60 event_ids: The event IDs
61
62 Returns:
63 A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
64 """
4165 hashes = await self.get_event_reference_hashes(event_ids)
4266 hashes = {
4367 e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
4670
4771 return list(hashes.items())
4872
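The return shapes of the two methods above, sketched with hypothetical event IDs (`get_event_reference_hashes` yields raw bytes, while `add_event_hashes` filters to sha256 and base64-encodes):

```python
hashes = await store.get_event_reference_hashes(["$event1", "$event2"])
# -> {"$event1": {"sha256": b"..."}, "$event2": {"sha256": b"..."}}

encoded = await store.add_event_hashes(["$event1"])
# -> [("$event1", {"sha256": "<unpadded-base64 digest>"})]
```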
49 def _get_event_reference_hashes_txn(self, txn, event_id):
73 def _get_event_reference_hashes_txn(
74 self, txn: Cursor, event_id: str
75 ) -> Dict[str, bytes]:
5076 """Get all the hashes for a given PDU.
5177 Args:
52 txn (cursor):
53 event_id (str): Id for the Event.
78 txn:
79 event_id: Id for the Event.
5480 Returns:
55 A dict[unicode, bytes] of algorithm -> hash.
81 A mapping of algorithm -> hash.
5682 """
5783 query = (
5884 "SELECT algorithm, hash"
2626 from synapse.storage.databases.main.events_worker import EventsWorkerStore
2727 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
2828 from synapse.storage.state import StateFilter
29 from synapse.types import StateMap
2930 from synapse.util.caches import intern_string
3031 from synapse.util.caches.descriptors import cached, cachedList
3132
162163 return create_event
163164
164165 @cached(max_entries=100000, iterable=True)
165 def get_current_state_ids(self, room_id):
166 async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
166167 """Get the current state event ids for a room based on the
167168 current_state_events table.
168169
169170 Args:
170 room_id (str)
171 room_id: The room to get the state IDs of.
171172
172173 Returns:
173 deferred: dict of (type, state_key) -> event_id
174 The current state of the room.
174175 """
175176
176177 def _get_current_state_ids_txn(txn):
183184
184185 return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
185186
186 return self.db_pool.runInteraction(
187 return await self.db_pool.runInteraction(
187188 "get_current_state_ids", _get_current_state_ids_txn
188189 )
189190
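`StateMap[str]` keys the current state by `(event type, state key)` pairs; a hedged illustration of the shape the now-async method returns (room and event IDs are made up):

```python
state = await store.get_current_state_ids("!room:example.org")
# {
#     ("m.room.create", ""): "$create_event_id",
#     ("m.room.member", "@alice:example.org"): "$membership_event_id",
# }
```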
190191 # FIXME: how should this be cached?
191 def get_filtered_current_state_ids(
192 async def get_filtered_current_state_ids(
192193 self, room_id: str, state_filter: StateFilter = StateFilter.all()
193 ):
194 ) -> StateMap[str]:
194195 """Get the current state event of a given type for a room based on the
195196 current_state_events table. This may not be as up-to-date as the result
196197 of doing a fresh state resolution as per state_handler.get_current_state
201202 from the database.
202203
203204 Returns:
204 defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
205 Map from type/state_key to event ID.
205206 """
206207
207208 where_clause, where_args = state_filter.make_sql_filter_clause()
208209
209210 if not where_clause:
210211 # We delegate to the cached version
211 return self.get_current_state_ids(room_id)
212 return await self.get_current_state_ids(room_id)
212213
213214 def _get_filtered_current_state_ids_txn(txn):
214215 results = {}
230231
231232 return results
232233
233 return self.db_pool.runInteraction(
234 return await self.db_pool.runInteraction(
234235 "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
235236 )
236237
259260 return event.content.get("canonical_alias")
260261
261262 @cached(max_entries=50000)
262 def _get_state_group_for_event(self, event_id):
263 return self.db_pool.simple_select_one_onecol(
263 async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
264 return await self.db_pool.simple_select_one_onecol(
264265 table="event_to_state_groups",
265266 keyvalues={"event_id": event_id},
266267 retcol="state_group",
272273 cached_method_name="_get_state_group_for_event",
273274 list_name="event_ids",
274275 num_args=1,
275 inlineCallbacks=True,
276276 )
277 def _get_state_group_for_events(self, event_ids):
277 async def _get_state_group_for_events(self, event_ids):
278278 """Returns mapping event_id -> state_group
279279 """
280 rows = yield self.db_pool.simple_select_many_batch(
280 rows = await self.db_pool.simple_select_many_batch(
281281 table="event_to_state_groups",
282282 column="event_id",
283283 iterable=event_ids,
1313 # limitations under the License.
1414
1515 import logging
16
17 from twisted.internet import defer
16 from typing import Any, Dict, List, Tuple
1817
1918 from synapse.storage._base import SQLBaseStore
2019
2221
2322
2423 class StateDeltasStore(SQLBaseStore):
25 def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
24 async def get_current_state_deltas(
25 self, prev_stream_id: int, max_stream_id: int
26 ) -> Tuple[int, List[Dict[str, Any]]]:
2627 """Fetch a list of room state changes since the given stream id
2728
2829 Each entry in the result contains the following fields:
3637 if it's new state.
3738
3839 Args:
39 prev_stream_id (int): point to get changes since (exclusive)
40 max_stream_id (int): the point that we know has been correctly persisted
40 prev_stream_id: point to get changes since (exclusive)
41 max_stream_id: the point that we know has been correctly persisted
4142 - ie, an upper limit to return changes from.
4243
4344 Returns:
44 Deferred[tuple[int, list[dict]]: A tuple consisting of:
45 A tuple consisting of:
4546 - the stream id which these results go up to
4647 - list of current_state_delta_stream rows. If it is empty, we are
4748 up to date.
5758 # if the CSDs haven't changed between prev_stream_id and now, we
5859 # know for certain that they haven't changed between prev_stream_id and
5960 # max_stream_id.
60 return defer.succeed((max_stream_id, []))
61 return (max_stream_id, [])
6162
6263 def get_current_state_deltas_txn(txn):
6364 # First we calculate the max stream id that will give us less than
101102 txn.execute(sql, (prev_stream_id, clipped_stream_id))
102103 return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
103104
104 return self.db_pool.runInteraction(
105 return await self.db_pool.runInteraction(
105106 "get_current_state_deltas", get_current_state_deltas_txn
106107 )
107108
113114 retcol="COALESCE(MAX(stream_id), -1)",
114115 )
115116
116 def get_max_stream_id_in_current_state_deltas(self):
117 return self.db_pool.runInteraction(
117 async def get_max_stream_id_in_current_state_deltas(self):
118 return await self.db_pool.runInteraction(
118119 "get_max_stream_id_in_current_state_deltas",
119120 self._get_max_stream_id_in_current_state_deltas_txn,
120121 )
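A hedged usage sketch of the now-async pair above, as a caller such as the stats processor might use them (positions are illustrative):

```python
prev_pos = 0  # the last position we processed (illustrative)
max_pos = await store.get_max_stream_id_in_current_state_deltas()
pos, deltas = await store.get_current_state_deltas(prev_pos, max_pos)
# `pos` is the stream ID the results are valid up to; if `deltas` is empty
# we are already up to date, thanks to the early-return path above.
```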
1414 # limitations under the License.
1515
1616 import logging
17 from collections import Counter
1718 from itertools import chain
18 from typing import Tuple
19 from typing import Any, Dict, List, Optional, Tuple
1920
2021 from twisted.internet.defer import DeferredLock
2122
210211
211212 return len(rooms_to_work_on)
212213
213 def get_stats_positions(self):
214 async def get_stats_positions(self) -> int:
214215 """
215216 Returns the stats processor positions.
216217 """
217 return self.db_pool.simple_select_one_onecol(
218 return await self.db_pool.simple_select_one_onecol(
218219 table="stats_incremental_position",
219220 keyvalues={},
220221 retcol="stream_id",
221222 desc="stats_incremental_position",
222223 )
223224
224 def update_room_state(self, room_id, fields):
225 """
225 async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
226 """Update the state of a room.
227
228 fields can contain the following keys with string values:
229 * join_rules
230 * history_visibility
231 * encryption
232 * name
233 * topic
234 * avatar
235 * canonical_alias
236
237 An is_federatable key can also be included with a boolean value.
238
226239 Args:
227 room_id (str)
228 fields (dict[str:Any])
229 """
230
231 # For whatever reason some of the fields may contain null bytes, which
232 # postgres isn't a fan of, so we replace those fields with null.
240 room_id: The room ID to update the state of.
241 fields: The fields to update. This can include a partial list of the
242 above fields to only update some room information.
243 """
244 # Ensure that the values to update are valid, they should be strings and
245 # not contain any null bytes.
246 #
247 # Invalid data gets overwritten with null.
248 #
249 # Note that a missing value should not be overwritten (it keeps the
250 # previous value).
251 sentinel = object()
233252 for col in (
234253 "join_rules",
235254 "history_visibility",
239258 "avatar",
240259 "canonical_alias",
241260 ):
242 field = fields.get(col)
243 if field and "\0" in field:
261 field = fields.get(col, sentinel)
262 if field is not sentinel and (not isinstance(field, str) or "\0" in field):
244263 fields[col] = None
245264
246 return self.db_pool.simple_upsert(
265 await self.db_pool.simple_upsert(
247266 table="room_stats_state",
248267 keyvalues={"room_id": room_id},
249268 values=fields,
250269 desc="update_room_state",
251270 )
252271
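The `sentinel = object()` idiom introduced above is worth isolating: a fresh object can never be a legitimate dictionary value, so it cleanly separates "key absent, keep the stored value" from "key present but invalid, null it out". A minimal sketch of the same technique:

```python
_SENTINEL = object()

def sanitise_field(fields: dict, col: str) -> None:
    value = fields.get(col, _SENTINEL)
    if value is _SENTINEL:
        return  # not supplied: leave the previously stored value untouched
    if not isinstance(value, str) or "\0" in value:
        fields[col] = None  # invalid (wrong type, or embedded null byte)
```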
253 def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
272 async def get_statistics_for_subject(
273 self, stats_type: str, stats_id: str, start: int, size: int = 100
274 ) -> List[dict]:
254275 """
255276 Get statistics for a given subject.
256277
257278 Args:
258 stats_type (str): The type of subject
259 stats_id (str): The ID of the subject (e.g. room_id or user_id)
260 start (int): Pagination start. Number of entries, not timestamp.
261 size (int): How many entries to return.
279 stats_type: The type of subject
280 stats_id: The ID of the subject (e.g. room_id or user_id)
281 start: Pagination start. Number of entries, not timestamp.
282 size: How many entries to return.
262283
263284 Returns:
264 Deferred[list[dict]], where the dict has the keys of
285 A list of dicts, where the dict has the keys of
265286 ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
266287 """
267 return self.db_pool.runInteraction(
288 return await self.db_pool.runInteraction(
268289 "get_statistics_for_subject",
269290 self._get_statistics_for_subject_txn,
270291 stats_type,
299320 return slice_list
300321
301322 @cached()
302 def get_earliest_token_for_stats(self, stats_type, id):
323 async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
303324 """
304325 Fetch the "earliest token". This is used by the room stats delta
305326 processor to ignore deltas that have been processed between the
307328 being calculated.
308329
309330 Returns:
310 Deferred[int]
331 The earliest token.
311332 """
312333 table, id_col = TYPE_TO_TABLE[stats_type]
313334
314 return self.db_pool.simple_select_one_onecol(
335 return await self.db_pool.simple_select_one_onecol(
315336 "%s_current" % (table,),
316337 keyvalues={id_col: id},
317338 retcol="completed_delta_stream_id",
318339 allow_none=True,
319340 )
320341
321 def bulk_update_stats_delta(self, ts, updates, stream_id):
342 async def bulk_update_stats_delta(
343 self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
344 ) -> None:
322345 """Bulk update stats tables for a given stream_id and updates the stats
323346 incremental position.
324347
325348 Args:
326 ts (int): Current timestamp in ms
327 updates(dict[str, dict[str, dict[str, Counter]]]): The updates to
328 commit as a mapping stats_type -> stats_id -> field -> delta.
329 stream_id (int): Current position.
330
331 Returns:
332 Deferred
349 ts: Current timestamp in ms
350 updates: The updates to commit as a mapping of
351 stats_type -> stats_id -> field -> delta.
352 stream_id: Current position.
333353 """
334354
335355 def _bulk_update_stats_delta_txn(txn):
354374 updatevalues={"stream_id": stream_id},
355375 )
356376
357 return self.db_pool.runInteraction(
377 await self.db_pool.runInteraction(
358378 "bulk_update_stats_delta", _bulk_update_stats_delta_txn
359379 )
360380
361 def update_stats_delta(
381 async def update_stats_delta(
362382 self,
363 ts,
364 stats_type,
365 stats_id,
366 fields,
367 complete_with_stream_id,
368 absolute_field_overrides=None,
369 ):
383 ts: int,
384 stats_type: str,
385 stats_id: str,
386 fields: Dict[str, int],
387 complete_with_stream_id: Optional[int],
388 absolute_field_overrides: Optional[Dict[str, int]] = None,
389 ) -> None:
370390 """
371391 Updates the statistics for a subject, with a delta (difference/relative
372392 change).
373393
374394 Args:
375 ts (int): timestamp of the change
376 stats_type (str): "room" or "user" – the kind of subject
377 stats_id (str): the subject's ID (room ID or user ID)
378 fields (dict[str, int]): Deltas of stats values.
379 complete_with_stream_id (int, optional):
395 ts: timestamp of the change
396 stats_type: "room" or "user" – the kind of subject
397 stats_id: the subject's ID (room ID or user ID)
398 fields: Deltas of stats values.
399 complete_with_stream_id:
380400 If supplied, converts an incomplete row into a complete row,
381401 with the supplied stream_id marked as the stream_id where the
382402 row was completed.
383 absolute_field_overrides (dict[str, int]): Current stats values
384 (i.e. not deltas) of absolute fields.
385 Does not work with per-slice fields.
386 """
387
388 return self.db_pool.runInteraction(
403 absolute_field_overrides: Current stats values (i.e. not deltas) of
404 absolute fields. Does not work with per-slice fields.
405 """
406
407 await self.db_pool.runInteraction(
389408 "update_stats_delta",
390409 self._update_stats_delta_txn,
391410 ts,
645664 txn, into_table, all_dest_keyvalues, src_row
646665 )
647666
648 def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
667 async def get_changes_room_total_events_and_bytes(
668 self, min_pos: int, max_pos: int
669 ) -> Dict[str, Dict[str, int]]:
649670 """Fetches the counts of events in the given range of stream IDs.
650671
651672 Args:
652 min_pos (int)
653 max_pos (int)
673 min_pos
674 max_pos
654675
655676 Returns:
656 Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field
657 changes.
658 """
659
660 return self.db_pool.runInteraction(
677 Mapping of room ID to field changes.
678 """
679
680 return await self.db_pool.runInteraction(
661681 "stats_incremental_total_events_and_bytes",
662682 self.get_changes_room_total_events_and_bytes_txn,
663683 min_pos,
3838 import abc
3939 import logging
4040 from collections import namedtuple
41 from typing import Optional
41 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
4242
4343 from twisted.internet import defer
4444
45 from synapse.api.filtering import Filter
46 from synapse.events import EventBase
4547 from synapse.logging.context import make_deferred_yieldable, run_in_background
4648 from synapse.storage._base import SQLBaseStore
47 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
49 from synapse.storage.database import (
50 DatabasePool,
51 LoggingTransaction,
52 make_in_list_sql_clause,
53 )
4854 from synapse.storage.databases.main.events_worker import EventsWorkerStore
49 from synapse.storage.engines import PostgresEngine
50 from synapse.types import RoomStreamToken
55 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
56 from synapse.types import Collection, RoomStreamToken
5157 from synapse.util.caches.stream_change_cache import StreamChangeCache
58
59 if TYPE_CHECKING:
60 from synapse.server import HomeServer
5261
5362 logger = logging.getLogger(__name__)
5463
6776
6877
6978 def generate_pagination_where_clause(
70 direction, column_names, from_token, to_token, engine
71 ):
79 direction: str,
80 column_names: Tuple[str, str],
81 from_token: Optional[Tuple[int, int]],
82 to_token: Optional[Tuple[int, int]],
83 engine: BaseDatabaseEngine,
84 ) -> str:
7285 """Creates an SQL expression to bound the columns by the pagination
7386 tokens.
7487
89102 token, but include those that match the to token.
90103
91104 Args:
92 direction (str): Whether we're paginating backwards("b") or
93 forwards ("f").
94 column_names (tuple[str, str]): The column names to bound. Must *not*
95 be user defined as these get inserted directly into the SQL
96 statement without escapes.
97 from_token (tuple[int, int]|None): The start point for the pagination.
98 This is an exclusive minimum bound if direction is "f", and an
99 inclusive maximum bound if direction is "b".
100 to_token (tuple[int, int]|None): The endpoint point for the pagination.
101 This is an inclusive maximum bound if direction is "f", and an
102 exclusive minimum bound if direction is "b".
105 direction: Whether we're paginating backwards ("b") or forwards ("f").
106 column_names: The column names to bound. Must *not* be user defined as
107 these get inserted directly into the SQL statement without escapes.
108 from_token: The start point for the pagination. This is an exclusive
109 minimum bound if direction is "f", and an inclusive maximum bound if
110 direction is "b".
111 to_token: The end point for the pagination. This is an inclusive
112 maximum bound if direction is "f", and an exclusive minimum bound if
113 direction is "b".
103114 engine: The database engine to generate the clauses for
104115
105116 Returns:
106 str: The sql expression
117 The sql expression
107118 """
108119 assert direction in ("b", "f")
109120
131142 return " AND ".join(where_clause)
132143
133144
134 def _make_generic_sql_bound(bound, column_names, values, engine):
145 def _make_generic_sql_bound(
146 bound: str,
147 column_names: Tuple[str, str],
148 values: Tuple[Optional[int], int],
149 engine: BaseDatabaseEngine,
150 ) -> str:
135151 """Create an SQL expression that bounds the given column names by the
136152 values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
137153
141157 out manually.
142158
143159 Args:
144 bound (str): The comparison operator to use. One of ">", "<", ">=",
160 bound: The comparison operator to use. One of ">", "<", ">=",
145161 "<=", where the values are on the left and columns on the right.
146 names (tuple[str, str]): The column names. Must *not* be user defined
162 column_names: The column names. Must *not* be user defined
147163 as these get inserted directly into the SQL statement without
148164 escapes.
149 values (tuple[int|None, int]): The values to bound the columns by. If
165 values: The values to bound the columns by. If
150166 the first value is None then only creates a bound on the second
151167 column.
152168 engine: The database engine to generate the SQL for
153169
154170 Returns:
155 str
171 The SQL statement
156172 """
157173
158174 assert bound in (">", "<", ">=", "<=")
192208 )
193209
194210
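To make the newly annotated helpers concrete, here is a hedged example call and roughly the SQL it produces; per the docstring above, values sit on the left of the comparison. Postgres supports row (tuple) comparisons directly, while the SQLite path has to expand the bound by hand:

```python
clause = generate_pagination_where_clause(
    direction="b",                      # paginating backwards
    column_names=("topological_ordering", "stream_ordering"),
    from_token=(10, 20),                # inclusive upper bound for "b"
    to_token=None,
    engine=engine,                      # a BaseDatabaseEngine (assumed in scope)
)
# Postgres, roughly: (10, 20) >= (topological_ordering, stream_ordering)
# SQLite, roughly:   10 > topological_ordering
#                    OR (10 = topological_ordering AND 20 >= stream_ordering)
```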
195 def filter_to_clause(event_filter):
211 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
196212 # NB: This may create SQL clauses that don't optimise well (and we don't
197213 # have indices on all possible clauses). E.g. it may create
198214 # "room_id == X AND room_id != X", which postgres doesn't optimise.
250266
251267 __metaclass__ = abc.ABCMeta
252268
253 def __init__(self, database: DatabasePool, db_conn, hs):
269 def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
254270 super(StreamWorkerStore, self).__init__(database, db_conn, hs)
255271
256272 self._instance_name = hs.get_instance_name()
283299 self._stream_order_on_start = self.get_room_max_stream_ordering()
284300
285301 @abc.abstractmethod
286 def get_room_max_stream_ordering(self):
302 def get_room_max_stream_ordering(self) -> int:
287303 raise NotImplementedError()
288304
289305 @abc.abstractmethod
290 def get_room_min_stream_ordering(self):
306 def get_room_min_stream_ordering(self) -> int:
291307 raise NotImplementedError()
292308
293 @defer.inlineCallbacks
294 def get_room_events_stream_for_rooms(
295 self, room_ids, from_key, to_key, limit=0, order="DESC"
296 ):
309 async def get_room_events_stream_for_rooms(
310 self,
311 room_ids: Collection[str],
312 from_key: str,
313 to_key: str,
314 limit: int = 0,
315 order: str = "DESC",
316 ) -> Dict[str, Tuple[List[EventBase], str]]:
297317 """Get new room events in stream ordering since `from_key`.
298318
299319 Args:
300 room_id (str)
301 from_key (str): Token from which no events are returned before
302 to_key (str): Token from which no events are returned after. (This
320 room_ids
321 from_key: Token from which no events are returned before
322 to_key: Token from which no events are returned after. (This
303323 is typically the current stream token)
304 limit (int): Maximum number of events to return
305 order (str): Either "DESC" or "ASC". Determines which events are
324 limit: Maximum number of events to return
325 order: Either "DESC" or "ASC". Determines which events are
306326 returned when the result is limited. If "DESC" then the most
307327 recent `limit` events are returned, otherwise returns the
308328 oldest `limit` events.
309329
310330 Returns:
311 Deferred[dict[str,tuple[list[FrozenEvent], str]]]
312 A map from room id to a tuple containing:
313 - list of recent events in the room
314 - stream ordering key for the start of the chunk of events returned.
331 A map from room id to a tuple containing:
332 - list of recent events in the room
333 - stream ordering key for the start of the chunk of events returned.
315334 """
316335 from_id = RoomStreamToken.parse_stream_token(from_key).stream
317336
318 room_ids = yield self._events_stream_cache.get_entities_changed(
319 room_ids, from_id
320 )
337 room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
321338
322339 if not room_ids:
323340 return {}
325342 results = {}
326343 room_ids = list(room_ids)
327344 for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
328 res = yield make_deferred_yieldable(
345 res = await make_deferred_yieldable(
329346 defer.gatherResults(
330347 [
331348 run_in_background(
345362
346363 return results
347364
348 def get_rooms_that_changed(self, room_ids, from_key):
365 def get_rooms_that_changed(
366 self, room_ids: Collection[str], from_key: str
367 ) -> Set[str]:
349368 """Given a list of rooms and a token, return rooms where there may have
350369 been changes.
351370
352371 Args:
353 room_ids (list)
354 from_key (str): The room_key portion of a StreamToken
355 """
356 from_key = RoomStreamToken.parse_stream_token(from_key).stream
372 room_ids
373 from_key: The room_key portion of a StreamToken
374 """
375 from_id = RoomStreamToken.parse_stream_token(from_key).stream
357376 return {
358377 room_id
359378 for room_id in room_ids
360 if self._events_stream_cache.has_entity_changed(room_id, from_key)
379 if self._events_stream_cache.has_entity_changed(room_id, from_id)
361380 }
362381
363 @defer.inlineCallbacks
364 def get_room_events_stream_for_room(
365 self, room_id, from_key, to_key, limit=0, order="DESC"
366 ):
367
382 async def get_room_events_stream_for_room(
383 self,
384 room_id: str,
385 from_key: str,
386 to_key: str,
387 limit: int = 0,
388 order: str = "DESC",
389 ) -> Tuple[List[EventBase], str]:
368390 """Get new room events in stream ordering since `from_key`.
369391
370392 Args:
371 room_id (str)
372 from_key (str): Token from which no events are returned before
373 to_key (str): Token from which no events are returned after. (This
393 room_id
394 from_key: Token from which no events are returned before
395 to_key: Token from which no events are returned after. (This
374396 is typically the current stream token)
375 limit (int): Maximum number of events to return
376 order (str): Either "DESC" or "ASC". Determines which events are
397 limit: Maximum number of events to return
398 order: Either "DESC" or "ASC". Determines which events are
377399 returned when the result is limited. If "DESC" then the most
378400 recent `limit` events are returned, otherwise returns the
379401 oldest `limit` events.
380402
381403 Returns:
382 Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
383 events (in ascending order) and the token from the start of
384 the chunk of events returned.
404 The list of events (in ascending order) and the token from the start
405 of the chunk of events returned.
385406 """
386407 if from_key == to_key:
387408 return [], from_key
389410 from_id = RoomStreamToken.parse_stream_token(from_key).stream
390411 to_id = RoomStreamToken.parse_stream_token(to_key).stream
391412
392 has_changed = yield self._events_stream_cache.has_entity_changed(
393 room_id, from_id
394 )
413 has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
395414
396415 if not has_changed:
397416 return [], from_key
409428 rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
410429 return rows
411430
412 rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
413
414 ret = yield self.get_events_as_list(
431 rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
432
433 ret = await self.get_events_as_list(
415434 [r.event_id for r in rows], get_prev_content=True
416435 )
417436
429448
430449 return ret, key
431450
432 @defer.inlineCallbacks
433 def get_membership_changes_for_user(self, user_id, from_key, to_key):
451 async def get_membership_changes_for_user(
452 self, user_id: str, from_key: str, to_key: str
453 ) -> List[EventBase]:
434454 from_id = RoomStreamToken.parse_stream_token(from_key).stream
435455 to_id = RoomStreamToken.parse_stream_token(to_key).stream
436456
459479
460480 return rows
461481
462 rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
463
464 ret = yield self.get_events_as_list(
482 rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
483
484 ret = await self.get_events_as_list(
465485 [r.event_id for r in rows], get_prev_content=True
466486 )
467487
469489
470490 return ret
471491
472 @defer.inlineCallbacks
473 def get_recent_events_for_room(self, room_id, limit, end_token):
492 async def get_recent_events_for_room(
493 self, room_id: str, limit: int, end_token: str
494 ) -> Tuple[List[EventBase], str]:
474495 """Get the most recent events in the room in topological ordering.
475496
476497 Args:
477 room_id (str)
478 limit (int)
479 end_token (str): The stream token representing now.
498 room_id
499 limit
500 end_token: The stream token representing now.
480501
481502 Returns:
482 Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
483 events and a token pointing to the start of the returned
484 events.
485 The events returned are in ascending order.
486 """
487
488 rows, token = yield self.get_recent_event_ids_for_room(
503 A list of events and a token pointing to the start of the returned
504 events. The events returned are in ascending order.
505 """
506
507 rows, token = await self.get_recent_event_ids_for_room(
489508 room_id, limit, end_token
490509 )
491510
492 events = yield self.get_events_as_list(
511 events = await self.get_events_as_list(
493512 [r.event_id for r in rows], get_prev_content=True
494513 )
495514
497516
498517 return (events, token)
499518
500 @defer.inlineCallbacks
501 def get_recent_event_ids_for_room(self, room_id, limit, end_token):
519 async def get_recent_event_ids_for_room(
520 self, room_id: str, limit: int, end_token: str
521 ) -> Tuple[List[_EventDictReturn], str]:
502522 """Get the most recent events in the room in topological ordering.
503523
504524 Args:
505 room_id (str)
506 limit (int)
507 end_token (str): The stream token representing now.
525 room_id
526 limit
527 end_token: The stream token representing now.
508528
509529 Returns:
510 Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
511 _EventDictReturn and a token pointing to the start of the returned
512 events.
513 The events returned are in ascending order.
530 A list of _EventDictReturn and a token pointing to the start of the
531 returned events. The events returned are in ascending order.
514532 """
515533 # Allow a zero limit here, and no-op.
516534 if limit == 0:
518536
519537 end_token = RoomStreamToken.parse(end_token)
520538
521 rows, token = yield self.db_pool.runInteraction(
539 rows, token = await self.db_pool.runInteraction(
522540 "get_recent_event_ids_for_room",
523541 self._paginate_room_events_txn,
524542 room_id,
531549
532550 return rows, token
533551
534 def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
552 async def get_room_event_before_stream_ordering(
553 self, room_id: str, stream_ordering: int
554 ) -> Tuple[int, int, str]:
535555 """Gets details of the first event in a room at or before a stream ordering
536556
537557 Args:
538 room_id (str):
539 stream_ordering (int):
558 room_id:
559 stream_ordering:
540560
541561 Returns:
542 Deferred[(int, int, str)]:
543 (stream ordering, topological ordering, event_id)
562 A tuple of (stream ordering, topological ordering, event_id)
544563 """
545564
546565 def _f(txn):
555574 txn.execute(sql, (room_id, stream_ordering))
556575 return txn.fetchone()
557576
558 return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
577 return await self.db_pool.runInteraction(
578 "get_room_event_before_stream_ordering", _f
579 )
559580
560581 async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
561582 """Returns the current token for rooms stream.
573594 )
574595 return "t%d-%d" % (topo, token)
575596
576 def get_stream_token_for_event(self, event_id):
577 """The stream token for an event
578 Args:
579 event_id(str): The id of the event to look up a stream token for.
597 async def get_stream_id_for_event(self, event_id: str) -> int:
598 """The stream ID for an event
599 Args:
600 event_id: The id of the event to look up a stream ID for.
580601 Raises:
581602 StoreError if the event wasn't in the database.
582603 Returns:
583 A deferred "s%d" stream token.
584 """
585 return self.db_pool.simple_select_one_onecol(
586 table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
587 ).addCallback(lambda row: "s%d" % (row,))
588
589 def get_topological_token_for_event(self, event_id):
604 A stream ID.
605 """
606 return await self.db_pool.runInteraction(
607 "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
608 )
609
610 def get_stream_id_for_event_txn(
611 self, txn: LoggingTransaction, event_id: str, allow_none: bool = False,
612 ) -> int:
613 return self.db_pool.simple_select_one_onecol_txn(
614 txn=txn,
615 table="events",
616 keyvalues={"event_id": event_id},
617 retcol="stream_ordering",
618 allow_none=allow_none,
619 )
620
621 async def get_stream_token_for_event(self, event_id: str) -> str:
590622 """The stream token for an event
591623 Args:
592 event_id(str): The id of the event to look up a stream token for.
624 event_id: The id of the event to look up a stream token for.
593625 Raises:
594626 StoreError if the event wasn't in the database.
595627 Returns:
596 A deferred "t%d-%d" topological token.
597 """
598 return self.db_pool.simple_select_one(
628 A "s%d" stream token.
629 """
630 stream_id = await self.get_stream_id_for_event(event_id)
631 return "s%d" % (stream_id,)
632
633 async def get_topological_token_for_event(self, event_id: str) -> str:
634 """The stream token for an event
635 Args:
636 event_id: The id of the event to look up a stream token for.
637 Raises:
638 StoreError if the event wasn't in the database.
639 Returns:
640 A "t%d-%d" topological token.
641 """
642 row = await self.db_pool.simple_select_one(
599643 table="events",
600644 keyvalues={"event_id": event_id},
601645 retcols=("stream_ordering", "topological_ordering"),
602646 desc="get_topological_token_for_event",
603 ).addCallback(
604 lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
605 )
606
607 def get_max_topological_token(self, room_id, stream_key):
608 """Get the max topological token in a room before the given stream
647 )
648 return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
649
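Per the docstrings above, the two token flavours have simple textual forms; a hedged usage sketch with a hypothetical event ID:

```python
stream_token = await store.get_stream_token_for_event("$someevent")
# e.g. "s42" (stream ordering 42)
topo_token = await store.get_topological_token_for_event("$someevent")
# e.g. "t5-42" (topological ordering 5, stream ordering 42)
```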
650 async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
651 """Gets the topological token in a room after or at the given stream
609652 ordering.
610653
611654 Args:
612 room_id (str)
613 stream_key (int)
614
615 Returns:
616 Deferred[int]
655 room_id
656 stream_key
617657 """
618658 sql = (
619 "SELECT coalesce(max(topological_ordering), 0) FROM events"
620 " WHERE room_id = ? AND stream_ordering < ?"
621 )
622 return self.db_pool.execute(
623 "get_max_topological_token", None, sql, room_id, stream_key
624 ).addCallback(lambda r: r[0][0] if r else 0)
625
626 def _get_max_topological_txn(self, txn, room_id):
659 "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
660 " WHERE room_id = ? AND stream_ordering >= ?"
661 )
662 row = await self.db_pool.execute(
663 "get_current_topological_token", None, sql, room_id, stream_key
664 )
665 return row[0][0] if row else 0
666
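The rename from `get_max_topological_token` to `get_current_topological_token` also flips the query's direction, as the SQL change shows; a worked illustration with made-up positions:

```python
# Suppose a room has events at (topological, stream) positions
# (3, 30), (5, 50) and (7, 70), and stream_key is 50.
#
# Old:  MAX(topological_ordering) WHERE stream_ordering <  50  -> 3
# New:  MIN(topological_ordering) WHERE stream_ordering >= 50  -> 5
```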
667 def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
627668 txn.execute(
628669 "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
629670 (room_id,),
633674 return rows[0][0] if rows else 0
634675
635676 @staticmethod
636 def _set_before_and_after(events, rows, topo_order=True):
677 def _set_before_and_after(
678 events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
679 ):
637680 """Inserts ordering information to events' internal metadata from
638681 the DB rows.
639682
640683 Args:
641 events (list[FrozenEvent])
642 rows (list[_EventDictReturn])
643 topo_order (bool): Whether the events were ordered topologically
644 or by stream ordering. If true then all rows should have a non
645 null topological_ordering.
684 events
685 rows
686 topo_order: Whether the events were ordered topologically or by stream
687 ordering. If true then all rows should have a non null
688 topological_ordering.
646689 """
647690 for event, row in zip(events, rows):
648691 stream = row.stream_ordering
655698 internal.after = str(RoomStreamToken(topo, stream))
656699 internal.order = (int(topo) if topo else 0, int(stream))
657700
658 @defer.inlineCallbacks
659 def get_events_around(
660 self, room_id, event_id, before_limit, after_limit, event_filter=None
661 ):
701 async def get_events_around(
702 self,
703 room_id: str,
704 event_id: str,
705 before_limit: int,
706 after_limit: int,
707 event_filter: Optional[Filter] = None,
708 ) -> dict:
662709 """Retrieve events and pagination tokens around a given event in a
663710 room.
664
665 Args:
666 room_id (str)
667 event_id (str)
668 before_limit (int)
669 after_limit (int)
670 event_filter (Filter|None)
671
672 Returns:
673 dict
674 """
675
676 results = yield self.db_pool.runInteraction(
711 """
712
713 results = await self.db_pool.runInteraction(
677714 "get_events_around",
678715 self._get_events_around_txn,
679716 room_id,
683720 event_filter,
684721 )
685722
686 events_before = yield self.get_events_as_list(
723 events_before = await self.get_events_as_list(
687724 list(results["before"]["event_ids"]), get_prev_content=True
688725 )
689726
690 events_after = yield self.get_events_as_list(
727 events_after = await self.get_events_as_list(
691728 list(results["after"]["event_ids"]), get_prev_content=True
692729 )
693730
699736 }
700737
701738 def _get_events_around_txn(
702 self, txn, room_id, event_id, before_limit, after_limit, event_filter
703 ):
739 self,
740 txn: LoggingTransaction,
741 room_id: str,
742 event_id: str,
743 before_limit: int,
744 after_limit: int,
745 event_filter: Optional[Filter],
746 ) -> dict:
704747 """Retrieves event_ids and pagination tokens around a given event in a
705748 room.
706749
707750 Args:
708 room_id (str)
709 event_id (str)
710 before_limit (int)
711 after_limit (int)
712 event_filter (Filter|None)
751 room_id
752 event_id
753 before_limit
754 after_limit
755 event_filter
713756
714757 Returns:
715758 dict
721764 keyvalues={"event_id": event_id, "room_id": room_id},
722765 retcols=["stream_ordering", "topological_ordering"],
723766 )
767
768 # This cannot happen as `allow_none=False`.
769 assert results is not None
724770
725771 # Paginating backwards includes the event at the token, but paginating
726772 # forward doesn't.
757803 "after": {"event_ids": events_after, "token": end_token},
758804 }
759805
760 @defer.inlineCallbacks
761 def get_all_new_events_stream(self, from_id, current_id, limit):
806 async def get_all_new_events_stream(
807 self, from_id: int, current_id: int, limit: int
808 ) -> Tuple[int, List[EventBase]]:
762809 """Get all new events
763810
764811 Returns all events with from_id < stream_ordering <= current_id.
765812
766813 Args:
767 from_id (int): the stream_ordering of the last event we processed
768 current_id (int): the stream_ordering of the most recently processed event
769 limit (int): the maximum number of events to return
814 from_id: the stream_ordering of the last event we processed
815 current_id: the stream_ordering of the most recently processed event
816 limit: the maximum number of events to return
770817
771818 Returns:
772 Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
773 `next_id` is the next value to pass as `from_id` (it will either be the
774 stream_ordering of the last returned event, or, if fewer than `limit` events
775 were found, `current_id`.
819 A tuple of (next_id, events), where `next_id` is the next value to
820 pass as `from_id` (it will either be the stream_ordering of the
821 last returned event, or, if fewer than `limit` events were found,
822 the `current_id`).
776823 """
777824
778825 def get_all_new_events_stream_txn(txn):
794841
795842 return upper_bound, [row[1] for row in rows]
796843
797 upper_bound, event_ids = yield self.db_pool.runInteraction(
844 upper_bound, event_ids = await self.db_pool.runInteraction(
798845 "get_all_new_events_stream", get_all_new_events_stream_txn
799846 )
800847
801 events = yield self.get_events_as_list(event_ids)
848 events = await self.get_events_as_list(event_ids)
802849
803850 return upper_bound, events
804851
816863 desc="get_federation_out_pos",
817864 )
818865
819 async def update_federation_out_pos(self, typ, stream_id):
866 async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
820867 if self._need_to_reset_federation_stream_positions:
821868 await self.db_pool.runInteraction(
822869 "_reset_federation_positions_txn", self._reset_federation_positions_txn
823870 )
824871 self._need_to_reset_federation_stream_positions = False
825872
826 return await self.db_pool.simple_update_one(
873 await self.db_pool.simple_update_one(
827874 table="federation_stream_position",
828875 keyvalues={"type": typ, "instance_name": self._instance_name},
829876 updatevalues={"stream_id": stream_id},
830877 desc="update_federation_out_pos",
831878 )
832879
833 def _reset_federation_positions_txn(self, txn):
880 def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
834881 """Fiddles with the `federation_stream_position` table to make it match
835882 the configured federation sender instances during start up.
836883 """
869916 GROUP BY type
870917 """
871918 txn.execute(sql)
872 min_positions = dict(txn) # Map from type -> min position
919 min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
873920
874921 # Ensure we do actually have some values here
875922 assert set(min_positions) == {"federation", "events"}
891938 values={"stream_id": stream_id},
892939 )
893940
894 def has_room_changed_since(self, room_id, stream_id):
941 def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
895942 return self._events_stream_cache.has_entity_changed(room_id, stream_id)
896943
897944 def _paginate_room_events_txn(
898945 self,
899 txn,
900 room_id,
901 from_token,
902 to_token=None,
903 direction="b",
904 limit=-1,
905 event_filter=None,
906 ):
946 txn: LoggingTransaction,
947 room_id: str,
948 from_token: RoomStreamToken,
949 to_token: Optional[RoomStreamToken] = None,
950 direction: str = "b",
951 limit: int = -1,
952 event_filter: Optional[Filter] = None,
953 ) -> Tuple[List[_EventDictReturn], str]:
907954 """Returns list of events before or after a given token.
908955
909956 Args:
910957 txn
911 room_id (str)
912 from_token (RoomStreamToken): The token used to stream from
913 to_token (RoomStreamToken|None): A token which if given limits the
914 results to only those before
915 direction(char): Either 'b' or 'f' to indicate whether we are
916 paginating forwards or backwards from `from_key`.
917 limit (int): The maximum number of events to return.
918 event_filter (Filter|None): If provided filters the events to
958 room_id
959 from_token: The token used to stream from
960 to_token: A token which if given limits the results to only those before
961 direction: Either 'b' or 'f' to indicate whether we are paginating
962 forwards or backwards from `from_key`.
963 limit: The maximum number of events to return.
964 event_filter: If provided filters the events to
919965 those that match the filter.
920966
921967 Returns:
922 Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
923 as a list of _EventDictReturn and a token that points to the end
924 of the result set. If no events are returned then the end of the
925 stream has been reached (i.e. there are no events between
926 `from_token` and `to_token`), or `limit` is zero.
968 A list of _EventDictReturn and a token that points to the end of the
969 result set. If no events are returned then the end of the stream has
970 been reached (i.e. there are no events between `from_token` and
971 `to_token`), or `limit` is zero.
927972 """
928973
929974 assert int(limit) >= 0
10071052
10081053 return rows, str(next_token)
10091054
1010 @defer.inlineCallbacks
1011 def paginate_room_events(
1012 self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
1013 ):
1055 async def paginate_room_events(
1056 self,
1057 room_id: str,
1058 from_key: str,
1059 to_key: Optional[str] = None,
1060 direction: str = "b",
1061 limit: int = -1,
1062 event_filter: Optional[Filter] = None,
1063 ) -> Tuple[List[EventBase], str]:
10141064 """Returns list of events before or after a given token.
10151065
10161066 Args:
1017 room_id (str)
1018 from_key (str): The token used to stream from
1019 to_key (str|None): A token which if given limits the results to
1020 only those before
1021 direction(char): Either 'b' or 'f' to indicate whether we are
1022 paginating forwards or backwards from `from_key`.
1023 limit (int): The maximum number of events to return.
1024 event_filter (Filter|None): If provided filters the events to
1025 those that match the filter.
1067 room_id
1068 from_key: The token used to stream from
1069 to_key: A token which if given limits the results to only those before
1070 direction: Either 'b' or 'f' to indicate whether we are paginating
1071 forwards or backwards from `from_key`.
1072 limit: The maximum number of events to return.
1073 event_filter: If provided filters the events to those that match the filter.
10261074
10271075 Returns:
1028 tuple[list[FrozenEvent], str]: Returns the results as a list of
1029 events and a token that points to the end of the result set. If no
1030 events are returned then the end of the stream has been reached
1031 (i.e. there are no events between `from_key` and `to_key`).
1076 The results as a list of events and a token that points to the end
1077 of the result set. If no events are returned then the end of the
1078 stream has been reached (i.e. there are no events between `from_key`
1079 and `to_key`).
10321080 """
10331081
10341082 from_key = RoomStreamToken.parse(from_key)
10351083 if to_key:
10361084 to_key = RoomStreamToken.parse(to_key)
10371085
1038 rows, token = yield self.db_pool.runInteraction(
1086 rows, token = await self.db_pool.runInteraction(
10391087 "paginate_room_events",
10401088 self._paginate_room_events_txn,
10411089 room_id,
10461094 event_filter,
10471095 )
10481096
1049 events = yield self.get_events_as_list(
1097 events = await self.get_events_as_list(
10501098 [r.event_id for r in rows], get_prev_content=True
10511099 )
10521100
10561104
10571105
10581106 class StreamStore(StreamWorkerStore):
1059 def get_room_max_stream_ordering(self):
1107 def get_room_max_stream_ordering(self) -> int:
10601108 return self._stream_id_gen.get_current_token()
10611109
1062 def get_room_min_stream_ordering(self):
1110 def get_room_min_stream_ordering(self) -> int:
10631111 return self._backfill_id_gen.get_current_token()
1616 import logging
1717 from typing import Dict, List, Tuple
1818
19 from canonicaljson import json
20
2119 from synapse.storage._base import db_to_json
2220 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
2321 from synapse.types import JsonDict
22 from synapse.util import json_encoder
2423 from synapse.util.caches.descriptors import cached
2524
2625 logger = logging.getLogger(__name__)
4342 "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
4443 )
4544
46 tags_by_room = {}
45 tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
4746 for row in rows:
4847 room_tags = tags_by_room.setdefault(row["room_id"], {})
4948 room_tags[row["tag"]] = db_to_json(row["content"])
9796 txn.execute(sql, (user_id, room_id))
9897 tags = []
9998 for tag, content in txn:
100 tags.append(json.dumps(tag) + ":" + content)
99 tags.append(json_encoder.encode(tag) + ":" + content)
101100 tag_json = "{" + ",".join(tags) + "}"
102101 results.append((stream_id, (user_id, room_id, tag_json)))
103102
123122
124123 async def get_updated_tags(
125124 self, user_id: str, stream_id: int
126 ) -> Dict[str, List[str]]:
125 ) -> Dict[str, Dict[str, JsonDict]]:
127126 """Get all the tags for the rooms where the tags have changed since the
128127 given version
129128
199198 Returns:
200199 The next account data ID.
201200 """
202 content_json = json.dumps(content)
201 content_json = json_encoder.encode(content)
203202
204203 def add_tag_txn(txn, next_id):
205204 self.db_pool.simple_upsert_txn(
210209 )
211210 self._update_revision_txn(txn, user_id, room_id, next_id)
212211
213 with self._account_data_id_gen.get_next() as next_id:
212 with await self._account_data_id_gen.get_next() as next_id:
214213 await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
215214
216215 self.get_tags_for_user.invalidate((user_id,))
232231 txn.execute(sql, (user_id, room_id, tag))
233232 self._update_revision_txn(txn, user_id, room_id, next_id)
234233
235 with self._account_data_id_gen.get_next() as next_id:
234 with await self._account_data_id_gen.get_next() as next_id:
236235 await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
237236
238237 self.get_tags_for_user.invalidate((user_id,))
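The swap from `canonicaljson.json` to `synapse.util.json_encoder` reuses one shared, pre-configured encoder instead of scattered `json.dumps` calls; the output is still plain JSON. (The ID generator's `get_next` similarly becomes awaitable here, hence `with await ... get_next()`.) A hedged one-liner:

```python
from synapse.util import json_encoder

tag_json = json_encoder.encode({"order": 0.5})  # a JSON string like '{"order": 0.5}'
```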
1414
1515 import logging
1616 from collections import namedtuple
17 from typing import Optional, Tuple
1718
1819 from canonicaljson import encode_canonical_json
1920
2021 from synapse.metrics.background_process_metrics import run_as_background_process
2122 from synapse.storage._base import SQLBaseStore, db_to_json
2223 from synapse.storage.database import DatabasePool
24 from synapse.types import JsonDict
2325 from synapse.util.caches.expiringcache import ExpiringCache
2426
2527 db_binary_type = memoryview
5456 expiry_ms=5 * 60 * 1000,
5557 )
5658
57 def get_received_txn_response(self, transaction_id, origin):
59 async def get_received_txn_response(
60 self, transaction_id: str, origin: str
61 ) -> Optional[Tuple[int, JsonDict]]:
5862 """For an incoming transaction from a given origin, check if we have
5963 already responded to it. If so, return the response code and response
6064 body (as a dict).
6165
6266 Args:
63 transaction_id (str)
64 origin(str)
67 transaction_id
68 origin
6569
6670 Returns:
67 tuple: None if we have not previously responded to
68 this transaction or a 2-tuple of (int, dict)
69 """
70
71 return self.db_pool.runInteraction(
71 None if we have not previously responded to this transaction or a
72 2-tuple of (int, dict)
73 """
74
75 return await self.db_pool.runInteraction(
7276 "get_received_txn_response",
7377 self._get_received_txn_response,
7478 transaction_id,
97101 else:
98102 return None
99103
100 def set_received_txn_response(self, transaction_id, origin, code, response_dict):
101 """Persist the response we returened for an incoming transaction, and
104 async def set_received_txn_response(
105 self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
106 ) -> None:
107 """Persist the response we returned for an incoming transaction, and
102108 should return for subsequent transactions with the same transaction_id
103109 and origin.
104110
105111 Args:
106 txn
107 transaction_id (str)
108 origin (str)
109 code (int)
110 response_json (str)
111 """
112
113 return self.db_pool.simple_insert(
112 transaction_id: The incoming transaction ID.
113 origin: The origin server.
114 code: The response code.
115 response_dict: The response, to be encoded into JSON.
116 """
117
118 await self.db_pool.simple_insert(
114119 table="received_transactions",
115120 values={
116121 "transaction_id": transaction_id,
163168 else:
164169 return None
165170
166 def set_destination_retry_timings(
167 self, destination, failure_ts, retry_last_ts, retry_interval
168 ):
171 async def set_destination_retry_timings(
172 self,
173 destination: str,
174 failure_ts: Optional[int],
175 retry_last_ts: int,
176 retry_interval: int,
177 ) -> None:
169178 """Sets the current retry timings for a given destination.
170179 Both timings should be zero if retrying is no longer occurring.
171180
172181 Args:
173 destination (str)
174 failure_ts (int|None) - when the server started failing (ms since epoch)
175 retry_last_ts (int) - time of last retry attempt in unix epoch ms
176 retry_interval (int) - how long until next retry in ms
182 destination
183 failure_ts: when the server started failing (ms since epoch)
184 retry_last_ts: time of last retry attempt in unix epoch ms
185 retry_interval: how long until next retry in ms
177186 """
178187
179188 self._destination_retry_cache.pop(destination, None)
180 return self.db_pool.runInteraction(
189 return await self.db_pool.runInteraction(
181190 "set_destination_retry_timings",
182191 self._set_destination_retry_timings,
183192 destination,
253262 "cleanup_transactions", self._cleanup_transactions
254263 )
255264
256 def _cleanup_transactions(self):
265 async def _cleanup_transactions(self) -> None:
257266 now = self._clock.time_msec()
258267 month_ago = now - 30 * 24 * 60 * 60 * 1000
259268
260269 def _cleanup_transactions_txn(txn):
261270 txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
262271
263 return self.db_pool.runInteraction(
272 await self.db_pool.runInteraction(
264273 "_cleanup_transactions", _cleanup_transactions_txn
265274 )
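Together, the two txn-response helpers above support replay suppression for federation transactions; a hedged sketch of a caller (the processing function is hypothetical):

```python
cached = await store.get_received_txn_response(txn_id, origin)
if cached is not None:
    code, body = cached  # seen before: replay the stored response
else:
    code, body = await process_transaction(txn)  # hypothetical handler
    await store.set_received_txn_response(txn_id, origin, code, body)
```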
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Dict, Optional, Union
14 from typing import Any, Dict, List, Optional, Tuple, Union
1515
1616 import attr
17 from canonicaljson import json
1817
1918 from synapse.api.errors import StoreError
2019 from synapse.storage._base import SQLBaseStore, db_to_json
20 from synapse.storage.database import LoggingTransaction
2121 from synapse.types import JsonDict
22 from synapse.util import stringutils as stringutils
22 from synapse.util import json_encoder, stringutils
2323
2424
2525 @attr.s
7171 StoreError if a unique session ID cannot be generated.
7272 """
7373 # The clientdict gets stored as JSON.
74 clientdict_json = json.dumps(clientdict)
74 clientdict_json = json_encoder.encode(clientdict)
7575
7676 # autogen a session ID and try to create it. We may clash, so just
7777 # try a few times till one goes through, giving up eventually.
142142 await self.db_pool.simple_upsert(
143143 table="ui_auth_sessions_credentials",
144144 keyvalues={"session_id": session_id, "stage_type": stage_type},
145 values={"result": json.dumps(result)},
145 values={"result": json_encoder.encode(result)},
146146 desc="mark_ui_auth_stage_complete",
147147 )
148148 except self.db_pool.engine.module.IntegrityError:
183183 The dictionary from the client root level, not the 'auth' key.
184184 """
185185 # The clientdict gets stored as JSON.
186 clientdict_json = json.dumps(clientdict)
186 clientdict_json = json_encoder.encode(clientdict)
187187
188188 await self.db_pool.simple_update_one(
189189 table="ui_auth_sessions",
213213 value,
214214 )
215215
216 def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
216 def _set_ui_auth_session_data_txn(
217 self, txn: LoggingTransaction, session_id: str, key: str, value: Any
218 ):
217219 # Get the current value.
218220 result = self.db_pool.simple_select_one_txn(
219221 txn,
220222 table="ui_auth_sessions",
221223 keyvalues={"session_id": session_id},
222224 retcols=("serverdict",),
223 )
225 ) # type: Dict[str, Any] # type: ignore
224226
225227 # Update it and add it back to the database.
226228 serverdict = db_to_json(result["serverdict"])
230232 txn,
231233 table="ui_auth_sessions",
232234 keyvalues={"session_id": session_id},
233 updatevalues={"serverdict": json.dumps(serverdict)},
235 updatevalues={"serverdict": json_encoder.encode(serverdict)},
234236 )
235237
236238 async def get_ui_auth_session_data(
257259
258260 return serverdict.get(key, default)
259261
262 async def add_user_agent_ip_to_ui_auth_session(
263 self, session_id: str, user_agent: str, ip: str,
264 ):
265 """Add the given user agent / IP to the tracking table
266 """
267 await self.db_pool.simple_upsert(
268 table="ui_auth_sessions_ips",
269 keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
270 values={},
271 desc="add_user_agent_ip_to_ui_auth_session",
272 )
273
274 async def get_user_agents_ips_to_ui_auth_session(
275 self, session_id: str,
276 ) -> List[Tuple[str, str]]:
277 """Get the given user agents / IPs used during the ui auth process
278
279 Returns:
280 List of user_agent/ip pairs
281 """
282 rows = await self.db_pool.simple_select_list(
283 table="ui_auth_sessions_ips",
284 keyvalues={"session_id": session_id},
285 retcols=("user_agent", "ip"),
286 desc="get_user_agents_ips_to_ui_auth_session",
287 )
288 return [(row["user_agent"], row["ip"]) for row in rows]
289
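A usage sketch for the new tracking helpers (the session ID and values are hypothetical):

```python
await store.add_user_agent_ip_to_ui_auth_session(
    "sessionabc", "Mozilla/5.0", "192.0.2.1",
)
pairs = await store.get_user_agents_ips_to_ui_auth_session("sessionabc")
# -> [("Mozilla/5.0", "192.0.2.1")]
```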
260290
261291 class UIAuthStore(UIAuthWorkerStore):
262 def delete_old_ui_auth_sessions(self, expiration_time: int):
292 async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
263293 """
264294 Remove sessions which were last used earlier than the expiration time.
265295
268298 This is an epoch time in milliseconds.
269299
270300 """
271 return self.db_pool.runInteraction(
301 await self.db_pool.runInteraction(
272302 "delete_old_ui_auth_sessions",
273303 self._delete_old_ui_auth_sessions_txn,
274304 expiration_time,
275305 )
276306
277 def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
307 def _delete_old_ui_auth_sessions_txn(
308 self, txn: LoggingTransaction, expiration_time: int
309 ):
278310 # Get the expired sessions.
279311 sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
280312 txn.execute(sql, [expiration_time])
281313 session_ids = [r[0] for r in txn.fetchall()]
282314
315 # Delete the corresponding IP/user agents.
316 self.db_pool.simple_delete_many_txn(
317 txn,
318 table="ui_auth_sessions_ips",
319 column="session_id",
320 iterable=session_ids,
321 keyvalues={},
322 )
323
283324 # Delete the corresponding completed credentials.
284325 self.db_pool.simple_delete_many_txn(
285326 txn,
1414
1515 import logging
1616 import re
17 from typing import Any, Dict, Iterable, Optional, Set, Tuple
1718
1819 from synapse.api.constants import EventTypes, JoinRules
1920 from synapse.storage.database import DatabasePool
363364
364365 return False
365366
366 def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
367 async def update_profile_in_user_dir(
368 self, user_id: str, display_name: str, avatar_url: str
369 ) -> None:
367370 """
368371 Update or add a user's profile in the user directory.
369372 """
373 # If the display name or avatar URL is of an unexpected type, replace it with None.
374 if not isinstance(display_name, str):
375 display_name = None
376 if not isinstance(avatar_url, str):
377 avatar_url = None
370378
371379 def _update_profile_in_user_dir_txn(txn):
372380 new_entry = self.db_pool.simple_upsert_txn(
456464
457465 txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
458466
459 return self.db_pool.runInteraction(
467 await self.db_pool.runInteraction(
460468 "update_profile_in_user_dir", _update_profile_in_user_dir_txn
461469 )
462470
463 def add_users_who_share_private_room(self, room_id, user_id_tuples):
471 async def add_users_who_share_private_room(
472 self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
473 ) -> None:
464474 """Insert entries into the users_who_share_private_rooms table. The first
465475 user should be a local user.
466476
467477 Args:
468 room_id (str)
469 user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
478 room_id
479 user_id_tuples: iterable of 2-tuple of user IDs.
470480 """
471481
472482 def _add_users_who_share_room_txn(txn):
482492 value_values=None,
483493 )
484494
485 return self.db_pool.runInteraction(
495 await self.db_pool.runInteraction(
486496 "add_users_who_share_room", _add_users_who_share_room_txn
487497 )
488498
489 def add_users_in_public_rooms(self, room_id, user_ids):
499 async def add_users_in_public_rooms(
500 self, room_id: str, user_ids: Iterable[str]
501 ) -> None:
490502 """Insert entries into the users_who_share_private_rooms table. The first
491503 user should be a local user.
492504
493505 Args:
494 room_id (str)
495 user_ids (list[str])
506 room_id
507 user_ids
496508 """
497509
498510 def _add_users_in_public_rooms_txn(txn):
506518 value_values=None,
507519 )
508520
509 return self.db_pool.runInteraction(
521 await self.db_pool.runInteraction(
510522 "add_users_in_public_rooms", _add_users_in_public_rooms_txn
511523 )
512524
513 def delete_all_from_user_dir(self):
525 async def delete_all_from_user_dir(self) -> None:
514526 """Delete the entire user directory
515527 """
516528
521533 txn.execute("DELETE FROM users_who_share_private_rooms")
522534 txn.call_after(self.get_user_in_directory.invalidate_all)
523535
524 return self.db_pool.runInteraction(
536 await self.db_pool.runInteraction(
525537 "delete_all_from_user_dir", _delete_all_from_user_dir_txn
526538 )
527539
528540 @cached()
529 def get_user_in_directory(self, user_id):
530 return self.db_pool.simple_select_one(
541 async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
542 return await self.db_pool.simple_select_one(
531543 table="user_directory",
532544 keyvalues={"user_id": user_id},
533545 retcols=("display_name", "avatar_url"),
535547 desc="get_user_in_directory",
536548 )
537549
538 def update_user_directory_stream_pos(self, stream_id):
539 return self.db_pool.simple_update_one(
550 async def update_user_directory_stream_pos(self, stream_id: str) -> None:
551 await self.db_pool.simple_update_one(
540552 table="user_directory_stream_pos",
541553 keyvalues={},
542554 updatevalues={"stream_id": stream_id},
553565 def __init__(self, database: DatabasePool, db_conn, hs):
554566 super(UserDirectoryStore, self).__init__(database, db_conn, hs)
555567
556 def remove_from_user_dir(self, user_id):
568 async def remove_from_user_dir(self, user_id: str) -> None:
557569 def _remove_from_user_dir_txn(txn):
558570 self.db_pool.simple_delete_txn(
559571 txn, table="user_directory", keyvalues={"user_id": user_id}
576588 )
577589 txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
578590
579 return self.db_pool.runInteraction(
591 await self.db_pool.runInteraction(
580592 "remove_from_user_dir", _remove_from_user_dir_txn
581593 )
582594
603615
604616 return user_ids
605617
606 def remove_user_who_share_room(self, user_id, room_id):
618 async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
607619 """
608620 Deletes entries in the users_who_share_*_rooms table. The first
609621 user should be a local user.
610622
611623 Args:
612 user_id (str)
613 room_id (str)
624 user_id
625 room_id
614626 """
615627
616628 def _remove_user_who_share_room_txn(txn):
630642 keyvalues={"user_id": user_id, "room_id": room_id},
631643 )
632644
633 return self.db_pool.runInteraction(
645 await self.db_pool.runInteraction(
634646 "remove_user_who_share_room", _remove_user_who_share_room_txn
635647 )
636648
662674 users.update(rows)
663675 return list(users)
664676
665 def get_user_directory_stream_pos(self):
666 return self.db_pool.simple_select_one_onecol(
677 @cached()
678 async def get_shared_rooms_for_users(
679 self, user_id: str, other_user_id: str
680 ) -> Set[str]:
681 """
682 Returns the rooms that a local user shares with another local or remote user.
683
684 Args:
685 user_id: The MXID of a local user
686 other_user_id: The MXID of the other user
687
688 Returns:
689 A set of room IDs that the users share.
690 """
691
692 def _get_shared_rooms_for_users_txn(txn):
693 txn.execute(
694 """
695 SELECT p1.room_id
696 FROM users_in_public_rooms as p1
697 INNER JOIN users_in_public_rooms as p2
698 ON p1.room_id = p2.room_id
699 AND p1.user_id = ?
700 AND p2.user_id = ?
701 UNION
702 SELECT room_id
703 FROM users_who_share_private_rooms
704 WHERE
705 user_id = ?
706 AND other_user_id = ?
707 """,
708 (user_id, other_user_id, user_id, other_user_id),
709 )
710 rows = self.db_pool.cursor_to_dict(txn)
711 return rows
712
713 rows = await self.db_pool.runInteraction(
714 "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
715 )
716
717 return {row["room_id"] for row in rows}
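The query above unions two sources: public rooms that both users are in, and rows from the private-share table. A runnable sketch of the same query shape against throwaway SQLite tables (table and column names mirror the diff; the data is illustrative):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE users_in_public_rooms (user_id TEXT, room_id TEXT);
        CREATE TABLE users_who_share_private_rooms (
            user_id TEXT, other_user_id TEXT, room_id TEXT
        );
        INSERT INTO users_in_public_rooms VALUES ('@a:x', '!pub:x'), ('@b:x', '!pub:x');
        INSERT INTO users_who_share_private_rooms VALUES ('@a:x', '@b:x', '!priv:x');
        """
    )
    rows = conn.execute(
        """
        SELECT p1.room_id FROM users_in_public_rooms AS p1
        INNER JOIN users_in_public_rooms AS p2 ON p1.room_id = p2.room_id
            AND p1.user_id = ? AND p2.user_id = ?
        UNION
        SELECT room_id FROM users_who_share_private_rooms
        WHERE user_id = ? AND other_user_id = ?
        """,
        ("@a:x", "@b:x", "@a:x", "@b:x"),
    ).fetchall()
    print({room_id for (room_id,) in rows})  # {'!pub:x', '!priv:x'}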
718
719 async def get_user_directory_stream_pos(self) -> int:
720 return await self.db_pool.simple_select_one_onecol(
667721 table="user_directory_stream_pos",
668722 keyvalues={},
669723 retcol="stream_id",
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import operator
16
1715 from synapse.storage._base import SQLBaseStore
1816 from synapse.util.caches.descriptors import cached, cachedList
1917
2018
2119 class UserErasureWorkerStore(SQLBaseStore):
2220 @cached()
23 def is_user_erased(self, user_id):
21 async def is_user_erased(self, user_id: str) -> bool:
2422 """
2523 Check if the given user id has requested erasure
2624
2725 Args:
28 user_id (str): full user id to check
26 user_id: full user id to check
2927
3028 Returns:
31 Deferred[bool]: True if the user has requested erasure
29 True if the user has requested erasure
3230 """
33 return self.db_pool.simple_select_onecol(
31 result = await self.db_pool.simple_select_onecol(
3432 table="erased_users",
3533 keyvalues={"user_id": user_id},
3634 retcol="1",
3735 desc="is_user_erased",
38 ).addCallback(operator.truth)
36 )
37 return bool(result)
3938
40 @cachedList(
41 cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
42 )
43 def are_users_erased(self, user_ids):
39 @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
40 async def are_users_erased(self, user_ids):
4441 """
4542 Checks which users in a list have requested erasure
4643
4845 user_ids (iterable[str]): full user id to check
4946
5047 Returns:
51 Deferred[dict[str, bool]]:
48 dict[str, bool]:
5249 for each user, whether the user has requested erasure.
5350 """
5451 # this serves the dual purpose of (a) making sure we can do len and
5552 # iterate it multiple times, and (b) avoiding duplicates.
5653 user_ids = tuple(set(user_ids))
5754
58 rows = yield self.db_pool.simple_select_many_batch(
55 rows = await self.db_pool.simple_select_many_batch(
5956 table="erased_users",
6057 column="user_id",
6158 iterable=user_ids,
6461 )
6562 erased_users = {row["user_id"] for row in rows}
6663
67 res = {u: u in erased_users for u in user_ids}
68 return res
64 return {u: u in erased_users for u in user_ids}
6965
7066
7167 class UserErasureStore(UserErasureWorkerStore):
72 def mark_user_erased(self, user_id: str) -> None:
68 async def mark_user_erased(self, user_id: str) -> None:
7369 """Indicate that user_id wishes their message history to be erased.
7470
7571 Args:
8783
8884 self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
8985
90 return self.db_pool.runInteraction("mark_user_erased", f)
86 await self.db_pool.runInteraction("mark_user_erased", f)
9187
92 def mark_user_not_erased(self, user_id: str) -> None:
88 async def mark_user_not_erased(self, user_id: str) -> None:
9389 """Indicate that user_id is no longer erased.
9490
9591 Args:
109105
110106 self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
111107
112 return self.db_pool.runInteraction("mark_user_not_erased", f)
108 await self.db_pool.runInteraction("mark_user_not_erased", f)
1515 import logging
1616 from collections import namedtuple
1717 from typing import Dict, Iterable, List, Set, Tuple
18
19 from twisted.internet import defer
2018
2119 from synapse.api.constants import EventTypes
2220 from synapse.storage._base import SQLBaseStore
102100 )
103101
104102 @cached(max_entries=10000, iterable=True)
105 def get_state_group_delta(self, state_group):
103 async def get_state_group_delta(self, state_group):
106104 """Given a state group try to return a previous group and a delta between
107105 the old and the new.
108106
134132 {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
135133 )
136134
137 return self.db_pool.runInteraction(
135 return await self.db_pool.runInteraction(
138136 "get_state_group_delta", _get_state_group_delta_txn
139137 )
140138
366364 fetched_keys=non_member_types,
367365 )
368366
369 def store_state_group(
367 async def store_state_group(
370368 self, event_id, room_id, prev_group, delta_ids, current_state_ids
371 ):
369 ) -> int:
372370 """Store a new set of state, returning a newly assigned state group.
373371
374372 Args:
382380 to event_id.
383381
384382 Returns:
385 Deferred[int]: The state group ID
383 The state group ID
386384 """
387385
388386 def _store_state_group_txn(txn):
483481
484482 return state_group
485483
486 return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
487
488 def purge_unreferenced_state_groups(
484 return await self.db_pool.runInteraction(
485 "store_state_group", _store_state_group_txn
486 )
487
488 async def purge_unreferenced_state_groups(
489489 self, room_id: str, state_groups_to_delete
490 ) -> defer.Deferred:
490 ) -> None:
491491 """Deletes no longer referenced state groups and de-deltas any state
492492 groups that reference them.
493493
498498 to delete.
499499 """
500500
501 return self.db_pool.runInteraction(
501 await self.db_pool.runInteraction(
502502 "purge_unreferenced_state_groups",
503503 self._purge_unreferenced_state_groups,
504504 room_id,
593593
594594 return {row["state_group"]: row["prev_state_group"] for row in rows}
595595
596 def purge_room_state(self, room_id, state_groups_to_delete):
596 async def purge_room_state(self, room_id, state_groups_to_delete):
597597 """Deletes all record of a room from state tables
598598
599599 Args:
601601 state_groups_to_delete (list[int]): State groups to delete
602602 """
603603
604 return self.db_pool.runInteraction(
604 await self.db_pool.runInteraction(
605605 "purge_room_state",
606606 self._purge_room_state_txn,
607607 room_id,
2121
2222
2323 @attr.s(slots=True, frozen=True)
24 class FetchKeyResult(object):
24 class FetchKeyResult:
2525 verify_key = attr.ib() # VerifyKey: the key itself
2626 valid_until_ts = attr.ib() # int: how long we can use this key for
6868 )
6969
7070
71 class _EventPeristenceQueue(object):
71 class _EventPeristenceQueue:
7272 """Queues up events so that they can be persisted in bulk with only one
7373 concurrent transaction per room.
7474 """
171171 pass
172172
173173
174 class EventsPersistenceStorage(object):
174 class EventsPersistenceStorage:
175175 """High level interface for handling persisting newly received events.
176176
177177 Takes care of batching up events by room, and calculating the necessary
4646 pass
4747
4848
49 OUTDATED_SCHEMA_ON_WORKER_ERROR = (
50 "Expected database schema version %i but got %i: run the main synapse process to "
51 "upgrade the database schema before starting worker processes."
52 )
53
54 EMPTY_DATABASE_ON_WORKER_ERROR = (
55 "Uninitialised database: run the main synapse process to prepare the database "
56 "schema before starting worker processes."
57 )
58
59 UNAPPLIED_DELTA_ON_WORKER_ERROR = (
60 "Database schema delta %s has not been applied: run the main synapse process to "
61 "upgrade the database schema before starting worker processes."
62 )
63
64
4965 def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
5066 """Prepares a physical database for usage. Will either create all necessary tables
5167 or upgrade from an older schema version.
6581
6682 try:
6783 cur = db_conn.cursor()
84
85 logger.info("%r: Checking existing schema version", databases)
6886 version_info = _get_or_create_schema_state(cur, database_engine)
6987
7088 if version_info:
7189 user_version, delta_files, upgraded = version_info
72
90 logger.info(
91 "%r: Existing schema is %i (+%i deltas)",
92 databases,
93 user_version,
94 len(delta_files),
95 )
96
97 # config should only be None when we are preparing an in-memory SQLite db,
98 # which should be empty.
7399 if config is None:
74 if user_version != SCHEMA_VERSION:
75 # If we don't pass in a config file then we are expecting to
76 # have already upgraded the DB.
77 raise UpgradeDatabaseException(
78 "Expected database schema version %i but got %i"
79 % (SCHEMA_VERSION, user_version)
80 )
81 else:
82 _upgrade_existing_database(
83 cur,
84 user_version,
85 delta_files,
86 upgraded,
87 database_engine,
88 config,
89 databases=databases,
100 raise ValueError(
101 "config==None in prepare_database, but databse is not empty"
90102 )
103
104 # if it's a worker app, refuse to upgrade the database, to avoid multiple
105 # workers doing it at once.
106 if config.worker_app is not None and user_version != SCHEMA_VERSION:
107 raise UpgradeDatabaseException(
108 OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version)
109 )
110
111 _upgrade_existing_database(
112 cur,
113 user_version,
114 delta_files,
115 upgraded,
116 database_engine,
117 config,
118 databases=databases,
119 )
91120 else:
121 logger.info("%r: Initialising new database", databases)
122
123 # if it's a worker app, refuse to upgrade the database, to avoid multiple
124 # workers doing it at once.
125 if config and config.worker_app is not None:
126 raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR)
127
92128 _setup_new_database(cur, database_engine, databases=databases)
93129
94130 # check if any of our configured dynamic modules want a database
294330 else:
295331 assert config
296332
333 is_worker = config and config.worker_app is not None
334
297335 if current_version > SCHEMA_VERSION:
298336 raise ValueError(
299337 "Cannot use this database as it is too "
321359 specific_engine_extensions = (".sqlite", ".postgres")
322360
323361 for v in range(start_ver, SCHEMA_VERSION + 1):
324 logger.info("Upgrading schema to v%d", v)
362 logger.info("Applying schema deltas for v%d", v)
325363
326364 # We need to search both the global and per data store schema
327365 # directories for schema updates.
381419 continue
382420
383421 root_name, ext = os.path.splitext(file_name)
422
384423 if ext == ".py":
385424 # This is a python upgrade module. We need to import into some
386425 # package and then execute its `run_upgrade` function.
426 if is_worker:
427 raise PrepareDatabaseException(
428 UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
429 )
430
387431 module_name = "synapse.storage.v%d_%s" % (v, root_name)
388432 with open(absolute_path) as python_file:
389433 module = imp.load_source(module_name, absolute_path, python_file)
398442 continue
399443 elif ext == ".sql":
400444 # A plain old .sql file, just read and execute it
445 if is_worker:
446 raise PrepareDatabaseException(
447 UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
448 )
401449 logger.info("Applying schema %s", relative_path)
402450 executescript(cur, absolute_path)
403451 elif ext == specific_engine_extension and root_name.endswith(".sql"):
404452 # A .sql file specific to our engine; just read and execute it
453 if is_worker:
454 raise PrepareDatabaseException(
455 UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
456 )
405457 logger.info("Applying engine-specific schema %s", relative_path)
406458 executescript(cur, absolute_path)
407459 elif ext in specific_engine_extensions and root_name.endswith(".sql"):
430482 ),
431483 (v, True),
432484 )
485
486 logger.info("Schema now up to date")
433487
434488
435489 def _apply_module_schemas(txn, database_engine, config):
568622
569623
570624 @attr.s()
571 class _DirectoryListing(object):
625 class _DirectoryListing:
572626 """Helper class to store schema file name and the
573627 absolute path to it.
574628
synapse/storage/presence.py (+0, -69): file deleted
0 # -*- coding: utf-8 -*-
1 # Copyright 2014-2016 OpenMarket Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from collections import namedtuple
16
17 from synapse.api.constants import PresenceState
18
19
20 class UserPresenceState(
21 namedtuple(
22 "UserPresenceState",
23 (
24 "user_id",
25 "state",
26 "last_active_ts",
27 "last_federation_update_ts",
28 "last_user_sync_ts",
29 "status_msg",
30 "currently_active",
31 ),
32 )
33 ):
34 """Represents the current presence state of the user.
35
36 user_id (str)
37 last_active (int): Time in msec that the user last interacted with server.
38 last_federation_update (int): Time in msec since either a) we sent a presence
39 update to other servers or b) we received a presence update, depending
40 on if is a local user or not.
41 last_user_sync (int): Time in msec that the user last *completed* a sync
42 (or event stream).
43 status_msg (str): User set status message.
44 """
45
46 def as_dict(self):
47 return dict(self._asdict())
48
49 @staticmethod
50 def from_dict(d):
51 return UserPresenceState(**d)
52
53 def copy_and_replace(self, **kwargs):
54 return self._replace(**kwargs)
55
56 @classmethod
57 def default(cls, user_id):
58 """Returns a default presence state.
59 """
60 return cls(
61 user_id=user_id,
62 state=PresenceState.OFFLINE,
63 last_active_ts=0,
64 last_federation_update_ts=0,
65 last_user_sync_ts=0,
66 status_msg=None,
67 currently_active=False,
68 )
1919 logger = logging.getLogger(__name__)
2020
2121
22 class PurgeEventsStorage(object):
22 class PurgeEventsStorage:
2323 """High level interface for purging rooms and event history.
2424 """
2525
2222
2323
2424 @attr.s
25 class PaginationChunk(object):
25 class PaginationChunk:
2626 """Returned by relation pagination APIs.
2727
2828 Attributes:
5050
5151
5252 @attr.s(frozen=True, slots=True)
53 class RelationPaginationToken(object):
53 class RelationPaginationToken:
5454 """Pagination token for relation pagination API.
5555
5656 As the results are in topological order, we can use the
8181
8282
8383 @attr.s(frozen=True, slots=True)
84 class AggregationPaginationToken(object):
84 class AggregationPaginationToken:
8585 """Pagination token for relation aggregation pagination API.
8686
8787 As the results are order by count and then MAX(stream_ordering) of the
2828
2929
3030 @attr.s(slots=True)
31 class StateFilter(object):
31 class StateFilter:
3232 """A filter used when querying for state.
3333
3434 Attributes:
325325 return member_filter, non_member_filter
326326
327327
328 class StateGroupStorage(object):
328 class StateGroupStorage:
329329 """High level interface to fetching state for event.
330330 """
331331
332332 def __init__(self, hs, stores):
333333 self.stores = stores
334334
335 def get_state_group_delta(self, state_group: int):
335 async def get_state_group_delta(self, state_group: int):
336336 """Given a state group try to return a previous group and a delta between
337337 the old and the new.
338338
340340 state_group: The state group used to retrieve state deltas.
341341
342342 Returns:
343 Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
343 Tuple[Optional[int], Optional[StateMap[str]]]:
344344 (prev_group, delta_ids)
345345 """
346346
347 return self.stores.state.get_state_group_delta(state_group)
347 return await self.stores.state.get_state_group_delta(state_group)
348348
349349 async def get_state_groups_ids(
350350 self, _room_id: str, event_ids: Iterable[str]
524524 state_filter: The state filter used to fetch state from the database.
525525
526526 Returns:
527 A deferred dict from (type, state_key) -> state_event
527 A dict from (type, state_key) -> state_event
528528 """
529529 state_map = await self.get_state_ids_for_events([event_id], state_filter)
530530 return state_map[event_id]
545545 """
546546 return self.stores.state._get_state_for_groups(groups, state_filter)
547547
548 def store_state_group(
548 async def store_state_group(
549549 self,
550550 event_id: str,
551551 room_id: str,
552552 prev_group: Optional[int],
553553 delta_ids: Optional[dict],
554554 current_state_ids: dict,
555 ):
555 ) -> int:
556556 """Store a new set of state, returning a newly assigned state group.
557557
558558 Args:
566566 to event_id.
567567
568568 Returns:
569 Deferred[int]: The state group ID
570 """
571 return self.stores.state.store_state_group(
569 The state group ID
570 """
571 return await self.stores.state.store_state_group(
572572 event_id, room_id, prev_group, delta_ids, current_state_ids
573573 )
1313 # limitations under the License.
1414
1515 import contextlib
16 import heapq
17 import logging
1618 import threading
1719 from collections import deque
18 from typing import Dict, Set, Tuple
20 from typing import Dict, List, Set
1921
2022 from typing_extensions import Deque
2123
2224 from synapse.storage.database import DatabasePool, LoggingTransaction
2325 from synapse.storage.util.sequence import PostgresSequenceGenerator
2426
25
26 class IdGenerator(object):
27 logger = logging.getLogger(__name__)
28
29
30 class IdGenerator:
2731 def __init__(self, db_conn, table, column):
2832 self._lock = threading.Lock()
2933 self._next_id = _load_current_id(db_conn, table, column)
4650 Returns:
4751 int
4852 """
53 # debug logging for https://github.com/matrix-org/synapse/issues/7968
54 logger.info("initialising stream generator for %s(%s)", table, column)
4955 cur = db_conn.cursor()
5056 if step == 1:
5157 cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
5763 return (max if step > 0 else min)(current_id, step)
5864
5965
60 class StreamIdGenerator(object):
66 class StreamIdGenerator:
6167 """Used to generate new stream ids when persisting events while keeping
6268 track of which transactions have been completed.
6369
7985 upwards, -1 to grow downwards.
8086
8187 Usage:
82 with stream_id_gen.get_next() as stream_id:
88 with await stream_id_gen.get_next() as stream_id:
8389 # ... persist event ...
8490 """
8591
94100 )
95101 self._unfinished_ids = deque() # type: Deque[int]
96102
97 def get_next(self):
103 async def get_next(self):
98104 """
99105 Usage:
100 with stream_id_gen.get_next() as stream_id:
106 with await stream_id_gen.get_next() as stream_id:
101107 # ... persist event ...
102108 """
103109 with self._lock:
116122
117123 return manager()
118124
119 def get_next_mult(self, n):
125 async def get_next_mult(self, n):
120126 """
121127 Usage:
122 with stream_id_gen.get_next(n) as stream_ids:
128 with await stream_id_gen.get_next(n) as stream_ids:
123129 # ... persist events ...
124130 """
125131 with self._lock:
157163
158164 return self._current
159165
160
161 class ChainedIdGenerator(object):
162 """Used to generate new stream ids where the stream must be kept in sync
163 with another stream. It generates pairs of IDs, the first element is an
164 integer ID for this stream, the second element is the ID for the stream
165 that this stream needs to be kept in sync with."""
166
167 def __init__(self, chained_generator, db_conn, table, column):
168 self.chained_generator = chained_generator
169 self._table = table
170 self._lock = threading.Lock()
171 self._current_max = _load_current_id(db_conn, table, column)
172 self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
173
174 def get_next(self):
175 """
176 Usage:
177 with stream_id_gen.get_next() as (stream_id, chained_id):
178 # ... persist event ...
179 """
180 with self._lock:
181 self._current_max += 1
182 next_id = self._current_max
183 chained_id = self.chained_generator.get_current_token()
184
185 self._unfinished_ids.append((next_id, chained_id))
186
187 @contextlib.contextmanager
188 def manager():
189 try:
190 yield (next_id, chained_id)
191 finally:
192 with self._lock:
193 self._unfinished_ids.remove((next_id, chained_id))
194
195 return manager()
196
197 def get_current_token(self):
198 """Returns the maximum stream id such that all stream ids less than or
199 equal to it have been successfully persisted.
200 """
201 with self._lock:
202 if self._unfinished_ids:
203 stream_id, chained_id = self._unfinished_ids[0]
204 return stream_id - 1, chained_id
205
206 return self._current_max, self.chained_generator.get_current_token()
207
208 def advance(self, token: int):
209 """Stub implementation for advancing the token when receiving updates
210 over replication; raises an exception as this instance should be the
211 only source of updates.
212 """
213
214 raise Exception(
215 "Attempted to advance token on source for table %r", self._table
216 )
166 def get_current_token_for_writer(self, instance_name: str) -> int:
167 """Returns the position of the given writer.
168
169 For streams with single writers this is equivalent to
170 `get_current_token`.
171 """
172 return self.get_current_token()
217173
218174
219175 class MultiWriterIdGenerator:
233189 id_column: Column that stores the stream ID.
234190 sequence_name: The name of the postgres sequence used to generate new
235191 IDs.
192 positive: Whether the IDs are positive (true) or negative (false).
193 When using negative IDs we go backwards from -1 to -2, -3, etc.
236194 """
237195
238196 def __init__(
244202 instance_column: str,
245203 id_column: str,
246204 sequence_name: str,
205 positive: bool = True,
247206 ):
248207 self._db = db
249208 self._instance_name = instance_name
209 self._positive = positive
210 self._return_factor = 1 if positive else -1
250211
251212 # We lock as some functions may be called from DB threads.
252213 self._lock = threading.Lock()
253214
215 # Note: If we are a negative stream then we still store all the IDs as
216 # positive to make life easier for us, and simply negate the IDs when we
217 # return them.
254218 self._current_positions = self._load_current_ids(
255219 db_conn, table, instance_column, id_column
256220 )
259223 # should be less than the minimum of this set (if not empty).
260224 self._unfinished_ids = set() # type: Set[int]
261225
226 # We track the max position where we know everything before has been
227 # persisted. This is done by a) looking at the min across all instances
228 # and b) noting that if we have seen a run of persisted positions
229 # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
230 #
231 # Note: There is no guarantee that the IDs generated by the sequence
232 # will be gapless; gaps can form when e.g. a transaction was rolled
233 # back. This means that sometimes we won't be able to skip forward the
234 # position even though everything has been persisted. However, since
235 # gaps should be relatively rare, it's still worth doing the bookkeeping
236 # that allows us to skip forwards when there are gapless runs of
237 # positions.
238 self._persisted_upto_position = (
239 min(self._current_positions.values()) if self._current_positions else 0
240 )
241 self._known_persisted_positions = [] # type: List[int]
242
262243 self._sequence_gen = PostgresSequenceGenerator(sequence_name)
263244
264245 def _load_current_ids(
265246 self, db_conn, table: str, instance_column: str, id_column: str
266247 ) -> Dict[str, int]:
248 # For a positive stream, aggregate via MAX; for a negative stream, use MIN
249 # *and* negate the result to get a positive number.
267250 sql = """
268 SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
251 SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
269252 GROUP BY %(instance)s
270253 """ % {
271254 "instance": instance_column,
272255 "id": id_column,
273256 "table": table,
257 "agg": "MAX" if self._positive else "-MIN",
274258 }
275259
276260 cur = db_conn.cursor()
283267
284268 return current_positions
285269
286 def _load_next_id_txn(self, txn):
270 def _load_next_id_txn(self, txn) -> int:
287271 return self._sequence_gen.get_next_id_txn(txn)
272
273 def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
274 return self._sequence_gen.get_next_mult_txn(txn, n)
288275
289276 async def get_next(self):
290277 """
297284 # Assert the fetched ID is actually greater than what we currently
298285 # believe the ID to be. If not, then the sequence and table have got
299286 # out of sync somehow.
300 assert self.get_current_token() < next_id
301
302 with self._lock:
287 with self._lock:
288 assert self._current_positions.get(self._instance_name, 0) < next_id
289
303290 self._unfinished_ids.add(next_id)
304291
305292 @contextlib.contextmanager
306293 def manager():
307294 try:
308 yield next_id
295 # Multiply by the return factor so that the ID has the correct sign.
296 yield self._return_factor * next_id
309297 finally:
310298 self._mark_id_as_finished(next_id)
311299
312300 return manager()
313301
302 async def get_next_mult(self, n: int):
303 """
304 Usage:
305 with await stream_id_gen.get_next_mult(5) as stream_ids:
306 # ... persist events ...
307 """
308 next_ids = await self._db.runInteraction(
309 "_load_next_mult_id", self._load_next_mult_id_txn, n
310 )
311
312 # Assert the fetched ID is actually greater than any ID we've already
313 # seen. If not, then the sequence and table have got out of sync
314 # somehow.
315 with self._lock:
316 assert max(self._current_positions.values(), default=0) < min(next_ids)
317
318 self._unfinished_ids.update(next_ids)
319
320 @contextlib.contextmanager
321 def manager():
322 try:
323 yield [self._return_factor * i for i in next_ids]
324 finally:
325 for i in next_ids:
326 self._mark_id_as_finished(i)
327
328 return manager()
329
314330 def get_next_txn(self, txn: LoggingTransaction):
315331 """
316332 Usage:
327343 txn.call_after(self._mark_id_as_finished, next_id)
328344 txn.call_on_exception(self._mark_id_as_finished, next_id)
329345
330 return next_id
346 return self._return_factor * next_id
331347
332348 def _mark_id_as_finished(self, next_id: int):
333349 """The ID has finished being processed so we should advance the
343359 curr = self._current_positions.get(self._instance_name, 0)
344360 self._current_positions[self._instance_name] = max(curr, next_id)
345361
346 def get_current_token(self, instance_name: str = None) -> int:
347 """Gets the current position of a named writer (defaults to current
348 instance).
349
350 Returns 0 if we don't have a position for the named writer (likely due
351 to it being a new writer).
352 """
353
354 if instance_name is None:
355 instance_name = self._instance_name
356
357 with self._lock:
358 return self._current_positions.get(instance_name, 0)
362 self._add_persisted_position(next_id)
363
364 def get_current_token(self) -> int:
365 """Returns the maximum stream id such that all stream ids less than or
366 equal to it have been successfully persisted.
367 """
368
369 # Currently we don't support this operation, as it's not obvious how to
370 # condense the stream positions of multiple writers into a single int.
371 raise NotImplementedError()
372
373 def get_current_token_for_writer(self, instance_name: str) -> int:
374 """Returns the position of the given writer.
375 """
376
377 with self._lock:
378 return self._return_factor * self._current_positions.get(instance_name, 0)
359379
360380 def get_positions(self) -> Dict[str, int]:
361381 """Get a copy of the current positon map.
362382 """
363383
364384 with self._lock:
365 return dict(self._current_positions)
385 return {
386 name: self._return_factor * i
387 for name, i in self._current_positions.items()
388 }
366389
367390 def advance(self, instance_name: str, new_id: int):
368391 """Advance the postion of the named writer to the given ID, if greater
369392 than existing entry.
370393 """
371394
395 new_id *= self._return_factor
396
372397 with self._lock:
373398 self._current_positions[instance_name] = max(
374399 new_id, self._current_positions.get(instance_name, 0)
375400 )
401
402 self._add_persisted_position(new_id)
403
404 def get_persisted_upto_position(self) -> int:
405 """Get the max position where all previous positions have been
406 persisted.
407
408 Note: In the worst case scenario this will be equal to the minimum
409 position across writers. This means that the returned position here can
410 lag if one writer doesn't write very often.
411 """
412
413 with self._lock:
414 return self._return_factor * self._persisted_upto_position
415
416 def _add_persisted_position(self, new_id: int):
417 """Record that we have persisted a position.
418
419 This is used to keep the `_current_positions` up to date.
420 """
421
422 # We require that the lock is held by the caller
423 assert self._lock.locked()
424
425 heapq.heappush(self._known_persisted_positions, new_id)
426
427 # We move the current min position up if the minimum current positions
428 # of all instances is higher (since by definition all positions less
429 # than that have been persisted).
430 min_curr = min(self._current_positions.values())
431 self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
432
433 # We now iterate through the seen positions, discarding those that are
434 # less than the current min position, and incrementing the min position
435 # if it's exactly one greater.
436 #
437 # This is also where we discard items from `_known_persisted_positions`
438 # (to ensure the list doesn't infinitely grow).
439 while self._known_persisted_positions:
440 if self._known_persisted_positions[0] <= self._persisted_upto_position:
441 heapq.heappop(self._known_persisted_positions)
442 elif (
443 self._known_persisted_positions[0] == self._persisted_upto_position + 1
444 ):
445 heapq.heappop(self._known_persisted_positions)
446 self._persisted_upto_position += 1
447 else:
448 # There was a gap in seen positions, so there is nothing more to
449 # do.
450 break
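The gapless-run bookkeeping above is the heart of `_add_persisted_position`. A standalone, runnable sketch of the same skip-forward logic (an illustrative helper, assuming positions arrive out of order):

    import heapq

    def advance_persisted_upto(persisted_upto, seen_heap, new_id):
        """Push new_id and advance persisted_upto across any gapless run."""
        heapq.heappush(seen_heap, new_id)
        while seen_heap:
            if seen_heap[0] <= persisted_upto:
                heapq.heappop(seen_heap)  # already covered
            elif seen_heap[0] == persisted_upto + 1:
                heapq.heappop(seen_heap)  # extends the gapless run
                persisted_upto += 1
            else:
                break  # gap: cannot advance past it yet
        return persisted_upto

    pos = 0
    heap = []
    for n in (1, 2, 5, 3):
        pos = advance_persisted_upto(pos, heap, n)
    print(pos)  # 3 -- we cannot pass the gap at 4 until it is persisted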
1313 # limitations under the License.
1414 import abc
1515 import threading
16 from typing import Callable, Optional
16 from typing import Callable, List, Optional
1717
1818 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
1919 from synapse.storage.types import Cursor
3737 def get_next_id_txn(self, txn: Cursor) -> int:
3838 txn.execute("SELECT nextval(?)", (self._sequence_name,))
3939 return txn.fetchone()[0]
40
41 def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
42 txn.execute(
43 "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
44 )
45 return [i for (i,) in txn]
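Worth noting: `generate_series(1, n)` yields n rows, so Postgres evaluates `nextval` once per row and a single statement returns n fresh sequence values in one round trip. The values are increasing within the call but, as with any sequence, are not guaranteed to be gapless across transactions.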
4046
4147
4248 GetFirstCallbackType = Callable[[Cursor], int]
2424 MAX_LIMIT = 1000
2525
2626
27 class SourcePaginationConfig(object):
27 class SourcePaginationConfig:
2828
2929 """A configuration object which stores pagination parameters for a
3030 specific event source."""
4444 )
4545
4646
47 class PaginationConfig(object):
47 class PaginationConfig:
4848
4949 """A configuration object which stores pagination parameters."""
5050
2222 from synapse.types import StreamToken
2323
2424
25 class EventSources(object):
25 class EventSources:
2626 SOURCE_TYPES = {
2727 "room": RoomEventSource,
2828 "presence": PresenceEventSource,
3838 self.store = hs.get_datastore()
3939
4040 def get_current_token(self) -> StreamToken:
41 push_rules_key, _ = self.store.get_push_rules_stream_token()
41 push_rules_key = self.store.get_max_push_rules_stream_id()
4242 to_device_key = self.store.get_to_device_stream_token()
4343 device_list_key = self.store.get_device_stream_token()
4444 groups_key = self.store.get_group_stream_token()
1717 import string
1818 import sys
1919 from collections import namedtuple
20 from typing import Any, Dict, Tuple, Type, TypeVar
20 from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
2121
2222 import attr
2323 from signedjson.key import decode_verify_key_bytes
4040 # Define a state map type from type/state_key to T (usually an event ID or
4141 # event)
4242 T = TypeVar("T")
43 StateMap = Dict[Tuple[str, str], T]
44
43 StateKey = Tuple[str, str]
44 StateMap = Mapping[StateKey, T]
45 MutableStateMap = MutableMapping[StateKey, T]
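A minimal sketch of how the new aliases are meant to be used, with the generic T fixed to `str` for illustration: `StateMap` is read-only, `MutableStateMap` permits writes.

    from typing import Mapping, MutableMapping, Tuple

    StateKey = Tuple[str, str]
    StateMap = Mapping[StateKey, str]
    MutableStateMap = MutableMapping[StateKey, str]

    def get_topic_event_id(state: StateMap) -> str:
        # Reading is fine on a StateMap...
        return state[("m.room.topic", "")]

    def set_topic_event_id(state: MutableStateMap, event_id: str) -> None:
        # ...but writing requires the mutable variant.
        state[("m.room.topic", "")] = event_id

    state = {("m.room.topic", ""): "$topic_event:example.org"}
    print(get_topic_event_id(state))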
4546
4647 # the type of a JSON-serialisable dict. This could be made stronger, but it will
4748 # do for now.
5051
5152 class Requester(
5253 namedtuple(
53 "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
54 "Requester",
55 [
56 "user",
57 "access_token_id",
58 "is_guest",
59 "shadow_banned",
60 "device_id",
61 "app_service",
62 ],
5463 )
5564 ):
5665 """
6170 access_token_id (int|None): *ID* of the access token used for this
6271 request, or None if it came via the appservice API or similar
6372 is_guest (bool): True if the user making this request is a guest user
73 shadow_banned (bool): True if the user making this request has been shadow-banned.
6474 device_id (str|None): device_id which was set at authentication time
6575 app_service (ApplicationService|None): the AS requesting on behalf of the user
6676 """
7686 "user_id": self.user.to_string(),
7787 "access_token_id": self.access_token_id,
7888 "is_guest": self.is_guest,
89 "shadow_banned": self.shadow_banned,
7990 "device_id": self.device_id,
8091 "app_server_id": self.app_service.id if self.app_service else None,
8192 }
100111 user=UserID.from_string(input["user_id"]),
101112 access_token_id=input["access_token_id"],
102113 is_guest=input["is_guest"],
114 shadow_banned=input["shadow_banned"],
103115 device_id=input["device_id"],
104116 app_service=appservice,
105117 )
106118
107119
108120 def create_requester(
109 user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
121 user_id,
122 access_token_id=None,
123 is_guest=False,
124 shadow_banned=False,
125 device_id=None,
126 app_service=None,
110127 ):
111128 """
112129 Create a new ``Requester`` object
116133 access_token_id (int|None): *ID* of the access token used for this
117134 request, or None if it came via the appservice API or similar
118135 is_guest (bool): True if the user making this request is a guest user
136 shadow_banned (bool): True if the user making this request is shadow-banned.
119137 device_id (str|None): device_id which was set at authentication time
120138 app_service (ApplicationService|None): the AS requesting on behalf of the user
121139
124142 """
125143 if not isinstance(user_id, UserID):
126144 user_id = UserID.from_string(user_id)
127 return Requester(user_id, access_token_id, is_guest, device_id, app_service)
145 return Requester(
146 user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
147 )
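A runnable sketch of the widened tuple (a stand-in namedtuple, not Synapse's actual Requester class):

    from collections import namedtuple

    Requester = namedtuple(
        "Requester",
        ["user", "access_token_id", "is_guest", "shadow_banned", "device_id", "app_service"],
    )

    req = Requester(
        user="@alice:example.org",
        access_token_id=None,
        is_guest=False,
        shadow_banned=True,  # the new field added in this diff
        device_id=None,
        app_service=None,
    )
    print(req.shadow_banned)  # True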
128148
129149
130150 def get_domain_from_id(string):
508528
509529
510530 @attr.s(slots=True)
511 class ReadReceipt(object):
531 class ReadReceipt:
512532 """Information about a read-receipt"""
513533
514534 room_id = attr.ib()
2424
2525 logger = logging.getLogger(__name__)
2626
27 # Create a custom encoder to reduce the whitespace produced by JSON encoding.
28 json_encoder = json.JSONEncoder(separators=(",", ":"))
27
28 def _reject_invalid_json(val):
29 """Do not allow Infinity, -Infinity, or NaN values in JSON."""
30 raise ValueError("Invalid JSON value: '%s'" % val)
31
32
33 # Create a custom encoder to reduce the whitespace produced by JSON encoding and
34 # ensure that valid JSON is produced.
35 json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
36
37 # Create a custom decoder to reject Python extensions to JSON.
38 json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
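A runnable sketch showing what the encoder/decoder pair accepts and rejects (it mirrors the definitions above):

    import json

    def _reject_invalid_json(val):
        raise ValueError("Invalid JSON value: '%s'" % val)

    json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
    json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)

    print(json_encoder.encode({"a": 1, "b": [2, 3]}))  # {"a":1,"b":[2,3]} -- no spaces
    try:
        json_encoder.encode({"bad": float("nan")})
    except ValueError as e:
        print("encode rejected:", e)
    try:
        json_decoder.decode('{"bad": NaN}')
    except ValueError as e:
        print("decode rejected:", e)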
2939
3040
3141 def unwrapFirstError(failure):
3545
3646
3747 @attr.s
38 class Clock(object):
48 class Clock:
3949 """
4050 A Clock wraps a Twisted reactor and provides utilities on top of it.
4151
1919 from typing import Dict, Sequence, Set, Union
2020
2121 import attr
22 from typing_extensions import ContextManager
2223
2324 from twisted.internet import defer
2425 from twisted.internet.defer import CancelledError
3435 logger = logging.getLogger(__name__)
3536
3637
37 class ObservableDeferred(object):
38 class ObservableDeferred:
3839 """Wraps a deferred object so that we can add observer deferreds. These
3940 observer deferreds do not affect the callback chain of the original
4041 deferred.
186187 ).addErrback(unwrapFirstError)
187188
188189
189 class Linearizer(object):
190 class Linearizer:
190191 """Limits concurrent access to resources based on a key. Useful to ensure
191192 only a few things happen at a time on a given resource.
192193
336337 return new_defer
337338
338339
339 class ReadWriteLock(object):
340 """A deferred style read write lock.
340 class ReadWriteLock:
341 """An async read write lock.
341342
342343 Example:
343344
344 with (yield read_write_lock.read("test_key")):
345 with await read_write_lock.read("test_key"):
345346 # do some work
346347 """
347348
364365 # Latest writer queued
365366 self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
366367
367 @defer.inlineCallbacks
368 def read(self, key):
368 async def read(self, key: str) -> ContextManager:
369369 new_defer = defer.Deferred()
370370
371371 curr_readers = self.key_to_current_readers.setdefault(key, set())
375375
376376 # We wait for the latest writer to finish writing. We can safely ignore
377377 # any existing readers... as they're readers.
378 yield make_deferred_yieldable(curr_writer)
378 if curr_writer:
379 await make_deferred_yieldable(curr_writer)
379380
380381 @contextmanager
381382 def _ctx_manager():
387388
388389 return _ctx_manager()
389390
390 @defer.inlineCallbacks
391 def write(self, key):
391 async def write(self, key: str) -> ContextManager:
392392 new_defer = defer.Deferred()
393393
394394 curr_readers = self.key_to_current_readers.get(key, set())
404404 curr_readers.clear()
405405 self.key_to_current_writer[key] = new_defer
406406
407 yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
407 await make_deferred_yieldable(defer.gatherResults(to_wait_on))
408408
409409 @contextmanager
410410 def _ctx_manager():
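A usage sketch of the converted lock from coroutine code, following the docstring idiom above; `lock`, `compute`, and `fetch` are illustrative stand-ins:

    async def update_cache(lock, key, compute):
        # Writers are exclusive: wait for readers and earlier writers to finish.
        with await lock.write(key):
            return await compute()

    async def read_cache(lock, key, fetch):
        # Readers only wait for the latest writer, never for other readers.
        with await lock.read(key):
            return await fetch()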
501501
502502
503503 @attr.s(slots=True, frozen=True)
504 class DoneAwaitable(object):
504 class DoneAwaitable:
505505 """Simple awaitable that returns the provided value.
506506 """
507507
4242
4343
4444 @attr.s
45 class CacheMetric(object):
45 class CacheMetric:
4646
4747 _cache = attr.ib()
4848 _cache_type = attr.ib(type=str)
1717 import inspect
1818 import logging
1919 import threading
20 from typing import Any, Tuple, Union, cast
20 from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
2121 from weakref import WeakValueDictionary
2222
2323 from prometheus_client import Gauge
24 from typing_extensions import Protocol
2524
2625 from twisted.internet import defer
2726
3736
3837 CacheKey = Union[Tuple, Any]
3938
40
41 class _CachedFunction(Protocol):
39 F = TypeVar("F", bound=Callable[..., Any])
40
41
42 class _CachedFunction(Generic[F]):
4243 invalidate = None # type: Any
4344 invalidate_all = None # type: Any
4445 invalidate_many = None # type: Any
4647 cache = None # type: Any
4748 num_args = None # type: Any
4849
49 def __name__(self):
50 ...
50 __name__ = None # type: str
51
52 # Note: This function signature is actually fiddled with by the synapse mypy
53 # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
54 __call__ = None # type: F
5155
5256
5357 cache_pending_metric = Gauge(
5963 _CacheSentinel = object()
6064
6165
62 class CacheEntry(object):
66 class CacheEntry:
6367 __slots__ = ["deferred", "callbacks", "invalidated"]
6468
6569 def __init__(self, deferred, callbacks):
7579 self.callbacks.clear()
7680
7781
78 class Cache(object):
82 class Cache:
7983 __slots__ = (
8084 "cache",
8185 "name",
122126
123127 self.name = name
124128 self.keylen = keylen
125 self.thread = None
129 self.thread = None # type: Optional[threading.Thread]
126130 self.metrics = register_cache(
127131 "cache",
128132 name,
283287 self._pending_deferred_cache.clear()
284288
285289
286 class _CacheDescriptorBase(object):
287 def __init__(
288 self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
289 ):
290 class _CacheDescriptorBase:
291 def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
290292 self.orig = orig
291
292 if inlineCallbacks:
293 self.function_to_call = defer.inlineCallbacks(orig)
294 else:
295 self.function_to_call = orig
296293
297294 arg_spec = inspect.getfullargspec(orig)
298295 all_args = arg_spec.args
363360 invalidated) by adding a special "cache_context" argument to the function
364361 and passing that as a kwarg to all caches called. For example::
365362
366 @cachedInlineCallbacks(cache_context=True)
363 @cached(cache_context=True)
367364 def foo(self, key, cache_context):
368365 r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
369366 r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
381378 max_entries=1000,
382379 num_args=None,
383380 tree=False,
384 inlineCallbacks=False,
385381 cache_context=False,
386382 iterable=False,
387383 ):
388384
389 super(CacheDescriptor, self).__init__(
390 orig,
391 num_args=num_args,
392 inlineCallbacks=inlineCallbacks,
393 cache_context=cache_context,
394 )
385 super().__init__(orig, num_args=num_args, cache_context=cache_context)
395386
396387 self.max_entries = max_entries
397388 self.tree = tree
464455 observer = defer.succeed(cached_result_d)
465456
466457 except KeyError:
467 ret = defer.maybeDeferred(
468 preserve_fn(self.function_to_call), obj, *args, **kwargs
469 )
458 ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
470459
471460 def onErr(f):
472461 cache.invalidate(cache_key)
509498 of results.
510499 """
511500
512 def __init__(
513 self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
514 ):
501 def __init__(self, orig, cached_method_name, list_name, num_args=None):
515502 """
516503 Args:
517504 orig (function)
520507 num_args (int): number of positional arguments (excluding ``self``,
521508 but including list_name) to use as cache keys. Defaults to all
522509 named args of the function.
523 inlineCallbacks (bool): Whether orig is a generator that should
524 be wrapped by defer.inlineCallbacks
525510 """
526 super(CacheListDescriptor, self).__init__(
527 orig, num_args=num_args, inlineCallbacks=inlineCallbacks
528 )
511 super().__init__(orig, num_args=num_args)
529512
530513 self.list_name = list_name
531514
630613
631614 cached_defers.append(
632615 defer.maybeDeferred(
633 preserve_fn(self.function_to_call), **args_to_call
616 preserve_fn(self.orig), **args_to_call
634617 ).addCallbacks(complete_all, errback)
635618 )
636619
682665
683666
684667 def cached(
685 max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
686 ):
687 return lambda orig: CacheDescriptor(
668 max_entries: int = 1000,
669 num_args: Optional[int] = None,
670 tree: bool = False,
671 cache_context: bool = False,
672 iterable: bool = False,
673 ) -> Callable[[F], _CachedFunction[F]]:
674 func = lambda orig: CacheDescriptor(
688675 orig,
689676 max_entries=max_entries,
690677 num_args=num_args,
693680 iterable=iterable,
694681 )
695682
696
697 def cachedInlineCallbacks(
698 max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
699 ):
700 return lambda orig: CacheDescriptor(
701 orig,
702 max_entries=max_entries,
703 num_args=num_args,
704 tree=tree,
705 inlineCallbacks=True,
706 cache_context=cache_context,
707 iterable=iterable,
708 )
709
710
711 def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
683 return cast(Callable[[F], _CachedFunction[F]], func)
684
685
686 def cachedList(
687 cached_method_name: str, list_name: str, num_args: Optional[int] = None
688 ) -> Callable[[F], _CachedFunction[F]]:
712689 """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
713690
714691 Used to do batch lookups for an already created cache. A single argument
718695 cache.
719696
720697 Args:
721 cached_method_name (str): The name of the single-item lookup method.
698 cached_method_name: The name of the single-item lookup method.
722699 This is only used to find the cache to use.
723 list_name (str): The name of the argument that is the list to use to
700 list_name: The name of the argument that is the list to use to
724701 do batch lookups in the cache.
725 num_args (int): Number of arguments to use as the key in the cache
702 num_args: Number of arguments to use as the key in the cache
726703 (including list_name). Defaults to all named parameters.
727 inlineCallbacks (bool): Should the function be wrapped in an
728 `defer.inlineCallbacks`?
729704
730705 Example:
731706
732 class Example(object):
707 class Example:
733708 @cached(num_args=2)
734709 def do_something(self, first_arg):
735710 ...
738713 def batch_do_something(self, first_arg, second_args):
739714 ...
740715 """
741 return lambda orig: CacheListDescriptor(
716 func = lambda orig: CacheListDescriptor(
742717 orig,
743718 cached_method_name=cached_method_name,
744719 list_name=list_name,
745720 num_args=num_args,
746 inlineCallbacks=inlineCallbacks,
747721 )
722
723 return cast(Callable[[F], _CachedFunction[F]], func)
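A sketch of the batch-lookup pattern under the new async-only signatures (the `Example` class and its arguments are illustrative; the decorators are the ones defined in this file):

    from synapse.util.caches.descriptors import cached, cachedList

    class Example:
        @cached(num_args=2)
        async def do_something(self, first_arg, second_arg):
            ...

        @cachedList(cached_method_name="do_something", list_name="second_args", num_args=2)
        async def batch_do_something(self, first_arg, second_args):
            # Must return a dict mapping each entry of `second_args` to its
            # result; hits are served from the do_something cache.
            ...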
3939 return len(self.value)
4040
4141
42 class DictionaryCache(object):
42 class DictionaryCache:
4343 """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
4444 fetching a subset of dictionary keys for a particular key.
4545 """
5252 self.thread = None
5353 # caches_by_name[name] = self.cache
5454
55 class Sentinel(object):
55 class Sentinel:
5656 __slots__ = []
5757
5858 self.sentinel = Sentinel()
2525 SENTINEL = object()
2626
2727
28 class ExpiringCache(object):
28 class ExpiringCache:
2929 def __init__(
3030 self,
3131 cache_name,
189189 return False
190190
191191
192 class _CacheEntry(object):
192 class _CacheEntry:
193193 __slots__ = ["time", "value"]
194194
195195 def __init__(self, time, value):
2929 yield m
3030
3131
32 class _Node(object):
32 class _Node:
3333 __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
3434
3535 def __init__(self, prev_node, next_node, key, value, callbacks=set()):
4040 self.callbacks = callbacks
4141
4242
43 class LruCache(object):
43 class LruCache:
4444 """
4545 Least-recently-used cache.
4646 Supports del_multi only if cache_type=TreeCache
2222 logger = logging.getLogger(__name__)
2323
2424
25 class ResponseCache(object):
25 class ResponseCache:
2626 """
2727 This caches a deferred response. Until the deferred completes it will be
2828 returned from the cache. This means that if the client retries the request
22 SENTINEL = object()
33
44
5 class TreeCache(object):
5 class TreeCache:
66 """
77 Tree-based backing store for LruCache. Allows subtrees of data to be deleted
88 efficiently.
8888 yield d
8989
9090
91 class _Entry(object):
91 class _Entry:
9292 __slots__ = ["value"]
9393
9494 def __init__(self, value):
2525 SENTINEL = object()
2626
2727
28 class TTLCache(object):
28 class TTLCache:
2929 """A key/value cache implementation where each entry has its own TTL"""
3030
3131 def __init__(self, cache_name, timer=time.time):
153153
154154
155155 @attr.s(frozen=True, slots=True)
156 class _CacheEntry(object):
156 class _CacheEntry:
157157 """TTLCache entry"""
158158
159159 # expiry_time is the first attribute, so that entries are sorted by expiry.
3333 distributor.fire("user_joined_room", user=user, room_id=room_id)
3434
3535
36 class Distributor(object):
36 class Distributor:
3737 """A central dispatch point for loosely-connected pieces of code to
3838 register, observe, and fire signals.
3939
102102 return succeed(result)
103103
104104
105 class Signal(object):
105 class Signal:
106106 """A Signal is a dispatch point that stores a list of callables as
107107 observers of it.
108108
1919 from synapse.logging.context import make_deferred_yieldable, run_in_background
2020
2121
22 class BackgroundFileConsumer(object):
22 class BackgroundFileConsumer:
2323 """A consumer that writes to a file like object. Supports both push
2424 and pull producers
2525
1313 # limitations under the License.
1414
1515
16 class JsonEncodedObject(object):
16 class JsonEncodedObject:
1717 """ A common base class for defining protocol units that are represented
1818 as JSON.
1919
9292 return wrapper
9393
9494
95 class Measure(object):
95 class Measure:
9696 __slots__ = [
9797 "clock",
9898 "name",
2828 logger = logging.getLogger(__name__)
2929
3030
31 class FederationRateLimiter(object):
31 class FederationRateLimiter:
3232 def __init__(self, clock, config):
3333 """
3434 Args:
5959 return self.ratelimiters[host].ratelimit()
6060
6161
62 class _PerHostRatelimiter(object):
62 class _PerHostRatelimiter:
6363 def __init__(self, clock, config):
6464 """
6565 Args:
113113 )
114114
115115
116 class RetryDestinationLimiter(object):
116 class RetryDestinationLimiter:
117117 def __init__(
118118 self,
119119 destination,
2323 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
2424
2525 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
26 # Note: The : character is allowed here for older clients, but will be removed in a
27 # future release. Context: https://github.com/matrix-org/synapse/issues/6766
28 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
26 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
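A runnable sketch of the tightened validation; the sample secrets are illustrative:

    import re

    client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")

    print(bool(client_secret_regex.match("abc.DEF_123-=")))  # True
    print(bool(client_secret_regex.match("abc:def")))        # False (colon no longer accepted)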
2927
3028 # random_string and random_string_with_symbols are used for a range of things,
3129 # some cryptographically important, some less so. We use SystemRandom to make sure
1313 # limitations under the License.
1414
1515
16 class _Entry(object):
16 class _Entry:
1717 __slots__ = ["end_key", "queue"]
1818
1919 def __init__(self, end_key):
2121 self.queue = []
2222
2323
24 class WheelTimer(object):
24 class WheelTimer:
2525 """Stores arbitrary objects that will be returned after their timers have
2626 expired.
2727 """
3535
3636 # Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
3737 Can upload self-signing keys
38
39 # Blacklisted until MSC2753 is implemented
40 Local users can peek into world_readable rooms by room ID
41 We can't peek into rooms with shared history_visibility
42 We can't peek into rooms with invited history_visibility
43 We can't peek into rooms with joined history_visibility
44 Local users can peek by room alias
45 Peeked rooms only turn up in the sync for the device who peeked them
3535 from tests.utils import mock_getRawHeaders, setup_test_homeserver
3636
3737
38 class TestHandlers(object):
38 class TestHandlers:
3939 def __init__(self, hs):
4040 self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
4141
368368 @defer.inlineCallbacks
369369 def test_filter_presence_match(self):
370370 user_filter_json = {"presence": {"types": ["m.*"]}}
371 filter_id = yield self.datastore.add_user_filter(
372 user_localpart=user_localpart, user_filter=user_filter_json
371 filter_id = yield defer.ensureDeferred(
372 self.datastore.add_user_filter(
373 user_localpart=user_localpart, user_filter=user_filter_json
374 )
373375 )
374376 event = MockEvent(sender="@foo:bar", type="m.profile")
375377 events = [event]
387389 def test_filter_presence_no_match(self):
388390 user_filter_json = {"presence": {"types": ["m.*"]}}
389391
390 filter_id = yield self.datastore.add_user_filter(
391 user_localpart=user_localpart + "2", user_filter=user_filter_json
392 filter_id = yield defer.ensureDeferred(
393 self.datastore.add_user_filter(
394 user_localpart=user_localpart + "2", user_filter=user_filter_json
395 )
392396 )
393397 event = MockEvent(
394398 event_id="$asdasd:localhost",
409413 @defer.inlineCallbacks
410414 def test_filter_room_state_match(self):
411415 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
412 filter_id = yield self.datastore.add_user_filter(
413 user_localpart=user_localpart, user_filter=user_filter_json
416 filter_id = yield defer.ensureDeferred(
417 self.datastore.add_user_filter(
418 user_localpart=user_localpart, user_filter=user_filter_json
419 )
414420 )
415421 event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
416422 events = [event]
427433 @defer.inlineCallbacks
428434 def test_filter_room_state_no_match(self):
429435 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
430 filter_id = yield self.datastore.add_user_filter(
431 user_localpart=user_localpart, user_filter=user_filter_json
436 filter_id = yield defer.ensureDeferred(
437 self.datastore.add_user_filter(
438 user_localpart=user_localpart, user_filter=user_filter_json
439 )
432440 )
433441 event = MockEvent(
434442 sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
464472 def test_add_filter(self):
465473 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
466474
467 filter_id = yield self.filtering.add_user_filter(
468 user_localpart=user_localpart, user_filter=user_filter_json
475 filter_id = yield defer.ensureDeferred(
476 self.filtering.add_user_filter(
477 user_localpart=user_localpart, user_filter=user_filter_json
478 )
469479 )
470480
471481 self.assertEquals(filter_id, 0)
484494 def test_get_filter(self):
485495 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
486496
487 filter_id = yield self.datastore.add_user_filter(
488 user_localpart=user_localpart, user_filter=user_filter_json
497 filter_id = yield defer.ensureDeferred(
498 self.datastore.add_user_filter(
499 user_localpart=user_localpart, user_filter=user_filter_json
500 )
489501 )
490502
491503 filter = yield defer.ensureDeferred(
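A recurring pattern in these test changes: datastore methods that used to return Deferreds are now coroutines, so `@defer.inlineCallbacks` tests must wrap them in `defer.ensureDeferred` before yielding. A minimal self-contained sketch, with a stand-in for the now-async `add_user_filter`:

```python
from twisted.internet import defer


async def add_user_filter(user_localpart, user_filter):
    # Stand-in for the datastore coroutine; returns a hypothetical filter ID.
    return 0


@defer.inlineCallbacks
def lookup_filter():
    # Yielding the coroutine directly would hand inlineCallbacks a coroutine
    # object rather than a Deferred; ensureDeferred bridges the two worlds.
    filter_id = yield defer.ensureDeferred(
        add_user_filter(user_localpart="carol", user_filter={"room": {}})
    )
    return filter_id
```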
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os.path
16 import tempfile
17
18 from synapse.config import ConfigError
19 from synapse.util.stringutils import random_string
20
21 from tests import unittest
22
23
24 class BaseConfigTestCase(unittest.HomeserverTestCase):
25 def prepare(self, reactor, clock, hs):
26 self.hs = hs
27
28 def test_loading_missing_templates(self):
29 # Use a temporary directory that exists on the system, but that isn't likely to
30 # contain template files
31 with tempfile.TemporaryDirectory() as tmp_dir:
32 # Attempt to load an HTML template from our custom template directory
33 template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
34
35 # If no errors, we should've gotten the default template instead
36
37 # Render the template
38 a_random_string = random_string(5)
39 html_content = template.render({"error_description": a_random_string})
40
41 # Check that our string exists in the template
42 self.assertIn(
43 a_random_string,
44 html_content,
45 "Template file did not contain our test string",
46 )
47
48 def test_loading_custom_templates(self):
49 # Use a temporary directory that exists on the system
50 with tempfile.TemporaryDirectory() as tmp_dir:
51 # Create a temporary bogus template file
52 with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
53 # Get temporary file's filename
54 template_filename = os.path.basename(tmp_template.name)
55
56 # Write a custom HTML template
57 contents = b"{{ test_variable }}"
58 tmp_template.write(contents)
59 tmp_template.flush()
60
61 # Attempt to load the template from our custom template directory
62 template = (
63 self.hs.config.read_templates([template_filename], tmp_dir)
64 )[0]
65
66 # Render the template
67 a_random_string = random_string(5)
68 html_content = template.render({"test_variable": a_random_string})
69
70 # Check that our string exists in the template
71 self.assertIn(
72 a_random_string,
73 html_content,
74 "Template file did not contain our test string",
75 )
76
77 def test_loading_template_from_nonexistent_custom_directory(self):
78 with self.assertRaises(ConfigError):
79 self.hs.config.read_templates(
80 ["some_filename.html"], "a_nonexistent_directory"
81 )
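The new test module above pins down `read_templates`' lookup order: a file missing from the custom directory falls back to the default template, while a nonexistent custom directory raises `ConfigError`. A hypothetical sketch of that order (not Synapse's actual implementation):

```python
import os


def resolve_template(filename, custom_dir, default_dir):
    # Sketch of the lookup order the tests above exercise.
    if not os.path.isdir(custom_dir):
        raise ValueError("custom template directory does not exist")
    custom_path = os.path.join(custom_dir, filename)
    if os.path.isfile(custom_path):
        return custom_path  # a custom template shadows the default
    return os.path.join(default_dir, filename)  # fall back to the default
```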
4242 from tests.test_utils import make_awaitable
4343
4444
45 class MockPerspectiveServer(object):
45 class MockPerspectiveServer:
4646 def __init__(self):
4747 self.server_name = "mock_server"
4848 self.key = signedjson.key.generate_signing_key(0)
189189
190190 # should fail immediately on an unsigned object
191191 d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
192 self.failureResultOf(d, SynapseError)
192 self.get_failure(d, SynapseError)
193193
194194 # should succeed on a signed object
195195 d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
220220
221221 # should fail immediately on an unsigned object
222222 d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
223 self.failureResultOf(d, SynapseError)
223 self.get_failure(d, SynapseError)
224224
225225 # should fail on a signed object with a non-zero minimum_valid_until_ms,
226226 # as it tries to refetch the keys and fails.
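The switch from `failureResultOf` to `get_failure` is another async-conversion artefact: `failureResultOf` asserts the Deferred has already failed, while Synapse's `get_failure` helper (from `HomeserverTestCase`) pumps the fake reactor until the result arrives, so it also covers coroutines that fail after a few reactor turns. A small sketch with a stand-in coroutine:

```python
from twisted.internet import defer


async def failing_lookup():
    # Stand-in for a keyring call that fails (SynapseError in the real tests).
    raise RuntimeError("no such key")


d = defer.ensureDeferred(failing_lookup())
# failureResultOf(d, Exc) works here only because the failure is already
# present; get_failure(d, Exc) additionally pumps the reactor first.
d.addErrback(lambda f: f.trap(RuntimeError))  # consume the failure
```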
1414
1515 from mock import Mock
1616
17 from twisted.internet import defer
18
1917 from synapse.api.errors import Codes, SynapseError
2018 from synapse.rest import admin
2119 from synapse.rest.client.v1 import login, room
5957
6058 # Artificially raise the complexity
6159 store = self.hs.get_datastore()
62 store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
60 store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
6361
6462 # Get the room complexity again -- make sure it's our artificial value
6563 request, channel = self.make_request(
7876 fed_transport = self.hs.get_federation_transport_client()
7977
8078 # Mock out some things, because we don't want to test the whole join
81 fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
82 handler.federation_handler.do_invite_join = Mock(
83 return_value=make_awaitable(("", 1))
79 fed_transport.client.get_json = Mock(
80 side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
81 )
82 handler.federation_handler.do_invite_join = Mock(
83 side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
8484 )
8585
8686 d = handler._remote_join(
109109 fed_transport = self.hs.get_federation_transport_client()
110110
111111 # Mock out some things, because we don't want to test the whole join
112 fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
113 handler.federation_handler.do_invite_join = Mock(
114 return_value=make_awaitable(("", 1))
112 fed_transport.client.get_json = Mock(
113 side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
114 )
115 handler.federation_handler.do_invite_join = Mock(
116 side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
115117 )
116118
117119 d = handler._remote_join(
147149 fed_transport = self.hs.get_federation_transport_client()
148150
149151 # Mock out some things, because we don't want to test the whole join
150 fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
151 handler.federation_handler.do_invite_join = Mock(
152 return_value=make_awaitable(("", 1))
152 fed_transport.client.get_json = Mock(
153 side_effect=lambda *args, **kwargs: make_awaitable(None)
154 )
155 handler.federation_handler.do_invite_join = Mock(
156 side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
153157 )
154158
155159 # Artificially raise the complexity
156 self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
160 self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
157161 600
158162 )
159163
203207 fed_transport = self.hs.get_federation_transport_client()
204208
205209 # Mock out some things, because we don't want to test the whole join
206 fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
207 handler.federation_handler.do_invite_join = Mock(
208 return_value=make_awaitable(("", 1))
210 fed_transport.client.get_json = Mock(
211 side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
212 )
213 handler.federation_handler.do_invite_join = Mock(
214 side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
209215 )
210216
211217 d = handler._remote_join(
233239 fed_transport = self.hs.get_federation_transport_client()
234240
235241 # Mock out some things, because we don't want to test the whole join
236 fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
237 handler.federation_handler.do_invite_join = Mock(
238 return_value=make_awaitable(("", 1))
242 fed_transport.client.get_json = Mock(
243 side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
244 )
245 handler.federation_handler.do_invite_join = Mock(
246 side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
239247 )
240248
241249 d = handler._remote_join(
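These hunks all swap `Mock(return_value=make_awaitable(...))` for `Mock(side_effect=lambda ...: make_awaitable(...))`. The likely reason (inferred from the pattern, not stated in the diff): an awaitable built from a coroutine can only be awaited once, so a mock that hands every caller the same object breaks on the second call, whereas `side_effect` manufactures a fresh awaitable per call. A sketch with a stand-in for the helper:

```python
from mock import Mock


def make_awaitable(result):
    # Stand-in for tests.test_utils.make_awaitable: a one-shot awaitable.
    async def _coro():
        return result

    return _coro()


# Every call returns the *same* coroutine object; awaiting it twice raises
# "cannot reuse already awaited coroutine":
bad = Mock(return_value=make_awaitable({"v1": 9999}))

# A fresh awaitable per call, safe for mocks invoked repeatedly:
good = Mock(side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}))
```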
1414 # limitations under the License.
1515 import logging
1616
17 from parameterized import parameterized
18
1719 from synapse.events import make_event_from_dict
1820 from synapse.federation.federation_server import server_matches_acl_event
1921 from synapse.rest import admin
2022 from synapse.rest.client.v1 import login, room
2123
2224 from tests import unittest
25
26
27 class FederationServerTests(unittest.FederatingHomeserverTestCase):
28
29 servlets = [
30 admin.register_servlets,
31 room.register_servlets,
32 login.register_servlets,
33 ]
34
35 @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
36 def test_bad_request(self, query_content):
37 """
38 Querying with bad data returns a reasonable error code.
39 """
40 u1 = self.register_user("u1", "pass")
41 u1_token = self.login("u1", "pass")
42
43 room_1 = self.helper.create_room_as(u1, tok=u1_token)
44 self.inject_room_member(room_1, "@user:other.example.com", "join")
45
46 "/get_missing_events/(?P<room_id>[^/]*)/?"
47
48 request, channel = self.make_request(
49 "POST",
50 "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
51 query_content,
52 )
53 self.render(request)
54 self.assertEquals(400, channel.code, channel.result)
55 self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
2356
2457
2558 class ServerACLsTestCase(unittest.TestCase):
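The new test above relies on `parameterized.expand`, which turns one decorated method into a separate test per argument tuple. A tiny standalone illustration:

```python
import unittest

from parameterized import parameterized


class PayloadTests(unittest.TestCase):
    # Expands into three separate test methods, one per payload below.
    @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
    def test_reject(self, payload):
        self.assertNotEqual(payload, b"valid json")
```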
2525
2626 class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
2727 def prepare(self, reactor, clock, homeserver):
28 class Authenticator(object):
28 class Authenticator:
2929 def authenticate_request(self, request, content):
3030 return defer.succeed("otherserver.nottld")
3131
2323 from synapse.handlers.auth import AuthHandler
2424
2525 from tests import unittest
26 from tests.test_utils import make_awaitable
2627 from tests.utils import setup_test_homeserver
2728
2829
29 class AuthHandlers(object):
30 class AuthHandlers:
3031 def __init__(self, hs):
3132 self.auth_handler = AuthHandler(hs)
3233
141142 def test_mau_limits_exceeded_large(self):
142143 self.auth_blocking._limit_usage_by_mau = True
143144 self.hs.get_datastore().get_monthly_active_count = Mock(
144 return_value=defer.succeed(self.large_number_of_users)
145 side_effect=lambda: make_awaitable(self.large_number_of_users)
145146 )
146147
147148 with self.assertRaises(ResourceLimitError):
152153 )
153154
154155 self.hs.get_datastore().get_monthly_active_count = Mock(
155 return_value=defer.succeed(self.large_number_of_users)
156 side_effect=lambda: make_awaitable(self.large_number_of_users)
156157 )
157158 with self.assertRaises(ResourceLimitError):
158159 yield defer.ensureDeferred(
167168
168169 # If not in monthly active cohort
169170 self.hs.get_datastore().get_monthly_active_count = Mock(
170 return_value=defer.succeed(self.auth_blocking._max_mau_value)
171 side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
171172 )
172173 with self.assertRaises(ResourceLimitError):
173174 yield defer.ensureDeferred(
177178 )
178179
179180 self.hs.get_datastore().get_monthly_active_count = Mock(
180 return_value=defer.succeed(self.auth_blocking._max_mau_value)
181 side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
181182 )
182183 with self.assertRaises(ResourceLimitError):
183184 yield defer.ensureDeferred(
187188 )
188189 # If in monthly active cohort
189190 self.hs.get_datastore().user_last_seen_monthly_active = Mock(
190 return_value=defer.succeed(self.hs.get_clock().time_msec())
191 )
192 self.hs.get_datastore().get_monthly_active_count = Mock(
193 return_value=defer.succeed(self.auth_blocking._max_mau_value)
191 side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
192 )
193 self.hs.get_datastore().get_monthly_active_count = Mock(
194 side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
194195 )
195196 yield defer.ensureDeferred(
196197 self.auth_handler.get_access_token_for_user_id(
198199 )
199200 )
200201 self.hs.get_datastore().user_last_seen_monthly_active = Mock(
201 return_value=defer.succeed(self.hs.get_clock().time_msec())
202 )
203 self.hs.get_datastore().get_monthly_active_count = Mock(
204 return_value=defer.succeed(self.auth_blocking._max_mau_value)
202 side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
203 )
204 self.hs.get_datastore().get_monthly_active_count = Mock(
205 side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
205206 )
206207 yield defer.ensureDeferred(
207208 self.auth_handler.validate_short_term_login_token_and_get_user_id(
214215 self.auth_blocking._limit_usage_by_mau = True
215216
216217 self.hs.get_datastore().get_monthly_active_count = Mock(
217 return_value=defer.succeed(self.small_number_of_users)
218 side_effect=lambda: make_awaitable(self.small_number_of_users)
218219 )
219220 # Ensure does not raise exception
220221 yield defer.ensureDeferred(
224225 )
225226
226227 self.hs.get_datastore().get_monthly_active_count = Mock(
227 return_value=defer.succeed(self.small_number_of_users)
228 side_effect=lambda: make_awaitable(self.small_number_of_users)
228229 )
229230 yield defer.ensureDeferred(
230231 self.auth_handler.validate_short_term_login_token_and_get_user_id(
7474 COOKIE_NAME = b"oidc_session"
7575 COOKIE_PATH = "/_synapse/oidc"
7676
77 MockedMappingProvider = Mock(OidcMappingProvider)
77
78 class TestMappingProvider(OidcMappingProvider):
79 @staticmethod
80 def parse_config(config):
81 return
82
83 def get_remote_user_id(self, userinfo):
84 return userinfo["sub"]
85
86 async def map_user_attributes(self, userinfo, token):
87 return {"localpart": userinfo["username"], "display_name": None}
7888
7989
8090 def simple_async_mock(return_value=None, raises=None):
122132 oidc_config["issuer"] = ISSUER
123133 oidc_config["scopes"] = SCOPES
124134 oidc_config["user_mapping_provider"] = {
125 "module": __name__ + ".MockedMappingProvider"
135 "module": __name__ + ".TestMappingProvider",
126136 }
127137 config["oidc_config"] = oidc_config
128138
373383 self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
374384 self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
375385 self.handler._auth_handler.complete_sso_login = simple_async_mock()
376 request = Mock(spec=["args", "getCookie", "addCookie"])
386 request = Mock(
387 spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
388 )
377389
378390 code = "code"
379391 state = "state"
380392 nonce = "nonce"
381393 client_redirect_url = "http://client/redirect"
394 user_agent = "Browser"
395 ip_address = "10.0.0.1"
382396 session = self.handler._generate_oidc_session_token(
383397 state=state,
384398 nonce=nonce,
391405 request.args[b"code"] = [code.encode("utf-8")]
392406 request.args[b"state"] = [state.encode("utf-8")]
393407
408 request.requestHeaders = Mock(spec=["getRawHeaders"])
409 request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
410 request.getClientIP.return_value = ip_address
411
394412 yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
395413
396414 self.handler._auth_handler.complete_sso_login.assert_called_once_with(
398416 )
399417 self.handler._exchange_code.assert_called_once_with(code)
400418 self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
401 self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
419 self.handler._map_userinfo_to_user.assert_called_once_with(
420 userinfo, token, user_agent, ip_address
421 )
402422 self.handler._fetch_userinfo.assert_not_called()
403423 self.handler._render_error.assert_not_called()
404424
430450 )
431451 self.handler._exchange_code.assert_called_once_with(code)
432452 self.handler._parse_id_token.assert_not_called()
433 self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
453 self.handler._map_userinfo_to_user.assert_called_once_with(
454 userinfo, token, user_agent, ip_address
455 )
434456 self.handler._fetch_userinfo.assert_called_once_with(token)
435457 self.handler._render_error.assert_not_called()
436458
567589 with self.assertRaises(OidcError) as exc:
568590 yield defer.ensureDeferred(self.handler._exchange_code(code))
569591 self.assertEqual(exc.exception.error, "some_error")
592
593 def test_map_userinfo_to_user(self):
594 """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
595 userinfo = {
596 "sub": "test_user",
597 "username": "test_user",
598 }
599 # The token doesn't matter with the default user mapping provider.
600 token = {}
601 mxid = self.get_success(
602 self.handler._map_userinfo_to_user(
603 userinfo, token, "user-agent", "10.10.10.10"
604 )
605 )
606 self.assertEqual(mxid, "@test_user:test")
607
608 # Some providers return an integer ID.
609 userinfo = {
610 "sub": 1234,
611 "username": "test_user_2",
612 }
613 mxid = self.get_success(
614 self.handler._map_userinfo_to_user(
615 userinfo, token, "user-agent", "10.10.10.10"
616 )
617 )
618 self.assertEqual(mxid, "@test_user_2:test")
1818 from signedjson.key import generate_signing_key
1919
2020 from synapse.api.constants import EventTypes, Membership, PresenceState
21 from synapse.api.presence import UserPresenceState
2122 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
2223 from synapse.events.builder import EventBuilder
2324 from synapse.handlers.presence import (
3132 handle_update,
3233 )
3334 from synapse.rest.client.v1 import room
34 from synapse.storage.presence import UserPresenceState
3535 from synapse.types import UserID, get_domain_from_id
3636
3737 from tests import unittest
2727 from tests.utils import setup_test_homeserver
2828
2929
30 class ProfileHandlers(object):
30 class ProfileHandlers:
3131 def __init__(self, hs):
3232 self.profile_handler = MasterProfileHandler(hs)
3333
6363 self.bob = UserID.from_string("@4567:test")
6464 self.alice = UserID.from_string("@alice:remote")
6565
66 yield self.store.create_profile(self.frank.localpart)
66 yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
6767
6868 self.handler = hs.get_profile_handler()
6969 self.hs = hs
7070
7171 @defer.inlineCallbacks
7272 def test_get_my_name(self):
73 yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
73 yield defer.ensureDeferred(
74 self.store.set_profile_displayname(self.frank.localpart, "Frank")
75 )
7476
7577 displayname = yield defer.ensureDeferred(
7678 self.handler.get_displayname(self.frank)
103105 )
104106
105107 self.assertEquals(
106 (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
108 (
109 yield defer.ensureDeferred(
110 self.store.get_profile_displayname(self.frank.localpart)
111 )
112 ),
113 "Frank",
107114 )
108115
109116 @defer.inlineCallbacks
111118 self.hs.config.enable_set_displayname = False
112119
113120 # Setting displayname for the first time is allowed
114 yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
115
116 self.assertEquals(
117 (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
121 yield defer.ensureDeferred(
122 self.store.set_profile_displayname(self.frank.localpart, "Frank")
123 )
124
125 self.assertEquals(
126 (
127 yield defer.ensureDeferred(
128 self.store.get_profile_displayname(self.frank.localpart)
129 )
130 ),
131 "Frank",
118132 )
119133
120134 # Setting displayname a second time is forbidden
156170
157171 @defer.inlineCallbacks
158172 def test_incoming_fed_query(self):
159 yield self.store.create_profile("caroline")
160 yield self.store.set_profile_displayname("caroline", "Caroline")
173 yield defer.ensureDeferred(self.store.create_profile("caroline"))
174 yield defer.ensureDeferred(
175 self.store.set_profile_displayname("caroline", "Caroline")
176 )
161177
162178 response = yield defer.ensureDeferred(
163179 self.query_handlers["profile"](
169185
170186 @defer.inlineCallbacks
171187 def test_get_my_avatar(self):
172 yield self.store.set_profile_avatar_url(
173 self.frank.localpart, "http://my.server/me.png"
188 yield defer.ensureDeferred(
189 self.store.set_profile_avatar_url(
190 self.frank.localpart, "http://my.server/me.png"
191 )
174192 )
175193 avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
176194
187205 )
188206
189207 self.assertEquals(
190 (yield self.store.get_profile_avatar_url(self.frank.localpart)),
208 (
209 yield defer.ensureDeferred(
210 self.store.get_profile_avatar_url(self.frank.localpart)
211 )
212 ),
191213 "http://my.server/pic.gif",
192214 )
193215
201223 )
202224
203225 self.assertEquals(
204 (yield self.store.get_profile_avatar_url(self.frank.localpart)),
226 (
227 yield defer.ensureDeferred(
228 self.store.get_profile_avatar_url(self.frank.localpart)
229 )
230 ),
205231 "http://my.server/me.png",
206232 )
207233
210236 self.hs.config.enable_set_avatar_url = False
211237
212238 # Setting displayname for the first time is allowed
213 yield self.store.set_profile_avatar_url(
214 self.frank.localpart, "http://my.server/me.png"
215 )
216
217 self.assertEquals(
218 (yield self.store.get_profile_avatar_url(self.frank.localpart)),
239 yield defer.ensureDeferred(
240 self.store.set_profile_avatar_url(
241 self.frank.localpart, "http://my.server/me.png"
242 )
243 )
244
245 self.assertEquals(
246 (
247 yield defer.ensureDeferred(
248 self.store.get_profile_avatar_url(self.frank.localpart)
249 )
250 ),
219251 "http://my.server/me.png",
220252 )
221253
1414
1515 from mock import Mock
1616
17 from twisted.internet import defer
18
17 from synapse.api.auth import Auth
1918 from synapse.api.constants import UserTypes
2019 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
2120 from synapse.handlers.register import RegistrationHandler
21 from synapse.spam_checker_api import RegistrationBehaviour
2222 from synapse.types import RoomAlias, UserID, create_requester
2323
2424 from tests.test_utils import make_awaitable
2525 from tests.unittest import override_config
26 from tests.utils import mock_getRawHeaders
2627
2728 from .. import unittest
2829
2930
30 class RegistrationHandlers(object):
31 class RegistrationHandlers:
3132 def __init__(self, hs):
3233 self.registration_handler = RegistrationHandler(hs)
3334
9899 def test_get_or_create_user_mau_not_blocked(self):
99100 self.hs.config.limit_usage_by_mau = True
100101 self.store.count_monthly_users = Mock(
101 return_value=defer.succeed(self.hs.config.max_mau_value - 1)
102 side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
102103 )
103104 # Ensure does not throw exception
104105 self.get_success(self.get_or_create_user(self.requester, "c", "User"))
106107 def test_get_or_create_user_mau_blocked(self):
107108 self.hs.config.limit_usage_by_mau = True
108109 self.store.get_monthly_active_count = Mock(
109 return_value=defer.succeed(self.lots_of_users)
110 side_effect=lambda: make_awaitable(self.lots_of_users)
110111 )
111112 self.get_failure(
112113 self.get_or_create_user(self.requester, "b", "display_name"),
114115 )
115116
116117 self.store.get_monthly_active_count = Mock(
117 return_value=defer.succeed(self.hs.config.max_mau_value)
118 side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
118119 )
119120 self.get_failure(
120121 self.get_or_create_user(self.requester, "b", "display_name"),
124125 def test_register_mau_blocked(self):
125126 self.hs.config.limit_usage_by_mau = True
126127 self.store.get_monthly_active_count = Mock(
127 return_value=defer.succeed(self.lots_of_users)
128 side_effect=lambda: make_awaitable(self.lots_of_users)
128129 )
129130 self.get_failure(
130131 self.handler.register_user(localpart="local_part"), ResourceLimitError
131132 )
132133
133134 self.store.get_monthly_active_count = Mock(
134 return_value=defer.succeed(self.hs.config.max_mau_value)
135 side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
135136 )
136137 self.get_failure(
137138 self.handler.register_user(localpart="local_part"), ResourceLimitError
474475 self.handler.register_user(localpart=invalid_user_id), SynapseError
475476 )
476477
478 def test_spam_checker_deny(self):
479 """A spam checker can deny registration, which results in an error."""
480
481 class DenyAll:
482 def check_registration_for_spam(
483 self, email_threepid, username, request_info
484 ):
485 return RegistrationBehaviour.DENY
486
487 # Configure a spam checker that denies all users.
488 spam_checker = self.hs.get_spam_checker()
489 spam_checker.spam_checkers = [DenyAll()]
490
491 self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
492
493 def test_spam_checker_shadow_ban(self):
494 """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
495
496 class BanAll:
497 def check_registration_for_spam(
498 self, email_threepid, username, request_info
499 ):
500 return RegistrationBehaviour.SHADOW_BAN
501
502 # Configure a spam checker that shadow-bans all users.
503 spam_checker = self.hs.get_spam_checker()
504 spam_checker.spam_checkers = [BanAll()]
505
506 user_id = self.get_success(self.handler.register_user(localpart="user"))
507
508 # Get an access token.
509 token = self.macaroon_generator.generate_access_token(user_id)
510 self.get_success(
511 self.store.add_access_token_to_user(
512 user_id=user_id, token=token, device_id=None, valid_until_ms=None
513 )
514 )
515
516 # Ensure the user was marked as shadow-banned.
517 request = Mock(args={})
518 request.args[b"access_token"] = [token.encode("ascii")]
519 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
520 auth = Auth(self.hs)
521 requester = self.get_success(auth.get_user_by_req(request))
522
523 self.assertTrue(requester.shadow_banned)
524
477525 async def get_or_create_user(
478526 self, requester, localpart, displayname, password_hash=None
479527 ):
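The new registration tests plug classes into `spam_checker.spam_checkers`; each one implements `check_registration_for_spam` and returns a `RegistrationBehaviour`. For contrast with `DenyAll` and `BanAll` above, a pass-through checker might look like this (hypothetical class, real enum):

```python
from synapse.spam_checker_api import RegistrationBehaviour


class AllowEveryone:
    def check_registration_for_spam(self, email_threepid, username, request_info):
        # ALLOW lets registration proceed; DENY rejects it; SHADOW_BAN lets it
        # succeed while flagging the account, as the tests above show.
        return RegistrationBehaviour.ALLOW
```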
8080 )
8181 )
8282
83 def get_all_room_state(self):
84 return self.store.db_pool.simple_select_list(
83 async def get_all_room_state(self):
84 return await self.store.db_pool.simple_select_list(
8585 "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
8686 )
8787
255255 # self.handler.notify_new_event()
256256
257257 # We need to let the delta processor advance…
258 self.pump(10 * 60)
258 self.reactor.advance(10 * 60)
259259
260260 # Get the slices! There should be two -- day 1, and day 2.
261261 r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
2020 from twisted.internet import defer
2121
2222 from synapse.api.errors import AuthError
23 from synapse.types import UserID
23 from synapse.types import UserID, create_requester
2424
2525 from tests import unittest
2626 from tests.test_utils import make_awaitable
143143
144144 self.datastore.get_users_in_room = get_users_in_room
145145
146 self.datastore.get_user_directory_stream_pos.return_value = (
146 self.datastore.get_user_directory_stream_pos.side_effect = (
147147 # we deliberately return a non-None stream pos to avoid doing an initial_spam
148 defer.succeed(1)
148 lambda: make_awaitable(1)
149149 )
150150
151151 self.datastore.get_current_state_deltas.return_value = (0, None)
154154 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
155155 ([], 0)
156156 )
157 self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
158 self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
157 self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
158 None
159 )
160 self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
159161 None
160162 )
161163
166168
167169 self.get_success(
168170 self.handler.started_typing(
169 target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
171 target_user=U_APPLE,
172 requester=create_requester(U_APPLE),
173 room_id=ROOM_ID,
174 timeout=20000,
170175 )
171176 )
172177
193198
194199 self.get_success(
195200 self.handler.started_typing(
196 target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
201 target_user=U_APPLE,
202 requester=create_requester(U_APPLE),
203 room_id=ROOM_ID,
204 timeout=20000,
197205 )
198206 )
199207
268276
269277 self.get_success(
270278 self.handler.stopped_typing(
271 target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
279 target_user=U_APPLE,
280 requester=create_requester(U_APPLE),
281 room_id=ROOM_ID,
272282 )
273283 )
274284
308318
309319 self.get_success(
310320 self.handler.started_typing(
311 target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
321 target_user=U_APPLE,
322 requester=create_requester(U_APPLE),
323 room_id=ROOM_ID,
324 timeout=10000,
312325 )
313326 )
314327
347360
348361 self.get_success(
349362 self.handler.started_typing(
350 target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
363 target_user=U_APPLE,
364 requester=create_requester(U_APPLE),
365 room_id=ROOM_ID,
366 timeout=10000,
351367 )
352368 )
353369
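The typing-handler hunks replace the `auth_user` keyword with a full `requester`; `create_requester` wraps a user ID into the `Requester` object the handler now expects. A minimal sketch:

```python
from synapse.types import UserID, create_requester

U_APPLE = UserID.from_string("@apple:test")

# create_requester accepts a UserID (or a plain "@user:server" string) and
# builds the Requester that started_typing/stopped_typing now take.
requester = create_requester(U_APPLE)
```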
237237
238238 def test_spam_checker(self):
239239 """
240 A user which fails to the spam checks will not appear in search results.
240 A user which fails the spam checks will not appear in search results.
241241 """
242242 u1 = self.register_user("user1", "pass")
243243 u1_token = self.login(u1, "pass")
268268 # Configure a spam checker that does not filter any users.
269269 spam_checker = self.hs.get_spam_checker()
270270
271 class AllowAll(object):
271 class AllowAll:
272272 def check_username_for_spam(self, user_profile):
273273 # Allow all users.
274274 return False
281281 self.assertEqual(len(s["results"]), 1)
282282
283283 # Configure a spam checker that filters all users.
284 class BlockAll(object):
284 class BlockAll:
285285 def check_username_for_spam(self, user_profile):
286286 # All users are spammy.
287287 return True
132132
133133
134134 @implementer(IOpenSSLServerConnectionCreator)
135 class TestServerTLSConnectionFactory(object):
135 class TestServerTLSConnectionFactory:
136136 """An SSL connection creator which returns connections which present a certificate
137137 signed by our test CA."""
138138
971971 def test_well_known_cache(self):
972972 self.reactor.lookups["testserv"] = "1.2.3.4"
973973
974 fetch_d = self.well_known_resolver.get_well_known(b"testserv")
974 fetch_d = defer.ensureDeferred(
975 self.well_known_resolver.get_well_known(b"testserv")
976 )
975977
976978 # there should be an attempt to connect on port 443 for the .well-known
977979 clients = self.reactor.tcpClients
994996 well_known_server.loseConnection()
995997
996998 # repeat the request: it should hit the cache
997 fetch_d = self.well_known_resolver.get_well_known(b"testserv")
999 fetch_d = defer.ensureDeferred(
1000 self.well_known_resolver.get_well_known(b"testserv")
1001 )
9981002 r = self.successResultOf(fetch_d)
9991003 self.assertEqual(r.delegated_server, b"target-server")
10001004
10021006 self.reactor.pump((1000.0,))
10031007
10041008 # now it should connect again
1005 fetch_d = self.well_known_resolver.get_well_known(b"testserv")
1009 fetch_d = defer.ensureDeferred(
1010 self.well_known_resolver.get_well_known(b"testserv")
1011 )
10061012
10071013 self.assertEqual(len(clients), 1)
10081014 (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
10251031
10261032 self.reactor.lookups["testserv"] = "1.2.3.4"
10271033
1028 fetch_d = self.well_known_resolver.get_well_known(b"testserv")
1034 fetch_d = defer.ensureDeferred(
1035 self.well_known_resolver.get_well_known(b"testserv")
1036 )
10291037
10301038 # there should be an attempt to connect on port 443 for the .well-known
10311039 clients = self.reactor.tcpClients
10511059 # another lookup.
10521060 self.reactor.pump((900.0,))
10531061
1054 fetch_d = self.well_known_resolver.get_well_known(b"testserv")
1062 fetch_d = defer.ensureDeferred(
1063 self.well_known_resolver.get_well_known(b"testserv")
1064 )
10551065
10561066 # The resolver may retry a few times, so fonx all requests that come along
10571067 attempts = 0
10811091 self.reactor.pump((10000.0,))
10821092
10831093 # Repeat the request; this time it should fail if the lookup fails.
1084 fetch_d = self.well_known_resolver.get_well_known(b"testserv")
1094 fetch_d = defer.ensureDeferred(
1095 self.well_known_resolver.get_well_known(b"testserv")
1096 )
10851097
10861098 clients = self.reactor.tcpClients
10871099 (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
12511263
12521264
12531265 @implementer(IPolicyForHTTPS)
1254 class TrustingTLSPolicyForHTTPS(object):
1266 class TrustingTLSPolicyForHTTPS:
12551267 """An IPolicyForHTTPS which checks that the certificate belongs to the
12561268 right server, but doesn't check the certificate chain."""
12571269
1515 from mock import Mock
1616
1717 from netaddr import IPSet
18 from parameterized import parameterized
1819
1920 from twisted.internet import defer
2021 from twisted.internet.defer import TimeoutError
507508 self.assertFalse(conn.disconnecting)
508509
509510 # wait for a while
510 self.pump(120)
511 self.reactor.advance(120)
511512
512513 self.assertTrue(conn.disconnecting)
514
515 @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
516 def test_json_error(self, return_value):
517 """
518 Test what happens if invalid JSON is returned from the remote endpoint.
519 """
520
521 test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
522
523 self.pump()
524
525 # Nothing happened yet
526 self.assertNoResult(test_d)
527
528 # Make sure treq is trying to connect
529 clients = self.reactor.tcpClients
530 self.assertEqual(len(clients), 1)
531 (host, port, factory, _timeout, _bindAddress) = clients[0]
532 self.assertEqual(host, "1.2.3.4")
533 self.assertEqual(port, 8008)
534
535 # complete the connection and wire it up to a fake transport
536 protocol = factory.buildProtocol(None)
537 transport = StringTransport()
538 protocol.makeConnection(transport)
539
540 # that should have made it send the request to the transport
541 self.assertRegex(transport.value(), b"^GET /foo/bar")
542 self.assertRegex(transport.value(), b"Host: testserv:8008")
543
544 # Deferred is still without a result
545 self.assertNoResult(test_d)
546
547 # Send it the HTTP response
548 protocol.dataReceived(
549 b"HTTP/1.1 200 OK\r\n"
550 b"Server: Fake\r\n"
551 b"Content-Type: application/json\r\n"
552 b"Content-Length: %i\r\n"
553 b"\r\n"
554 b"%s" % (len(return_value), return_value)
555 )
556
557 self.pump()
558
559 f = self.failureResultOf(test_d)
560 self.assertIsInstance(f.value, ValueError)
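`b'{"a": Infinity}'` makes a good bad-JSON payload because Python's stdlib parser accepts the non-standard `Infinity` token by default, so a server has to reject it explicitly. A stdlib-only sketch of that behaviour (not Synapse's actual parser):

```python
import json

# The stdlib is lenient: Infinity/NaN parse to floats by default.
assert json.loads('{"a": Infinity}') == {"a": float("inf")}


def reject_constants(name):
    # parse_constant is invoked for Infinity, -Infinity and NaN.
    raise ValueError("invalid JSON constant: %s" % (name,))


try:
    json.loads('{"a": Infinity}', parse_constant=reject_constants)
except ValueError as exc:
    print(exc)  # invalid JSON constant: Infinity
```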
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import json
15 from io import BytesIO
16
17 from mock import Mock
18
19 from synapse.api.errors import SynapseError
20 from synapse.http.servlet import (
21 parse_json_object_from_request,
22 parse_json_value_from_request,
23 )
24
25 from tests import unittest
26
27
28 def make_request(content):
29 """Make an object that acts enough like a request."""
30 request = Mock(spec=["content"])
31
32 if isinstance(content, dict):
33 content = json.dumps(content).encode("utf8")
34
35 request.content = BytesIO(content)
36 return request
37
38
39 class TestServletUtils(unittest.TestCase):
40 def test_parse_json_value(self):
41 """Basic tests for parse_json_value_from_request."""
42 # Test round-tripping.
43 obj = {"foo": 1}
44 result = parse_json_value_from_request(make_request(obj))
45 self.assertEqual(result, obj)
46
47 # Results don't have to be objects.
48 result = parse_json_value_from_request(make_request(b'["foo"]'))
49 self.assertEqual(result, ["foo"])
50
51 # Test empty.
52 with self.assertRaises(SynapseError):
53 parse_json_value_from_request(make_request(b""))
54
55 result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
56 self.assertIsNone(result)
57
58 # Invalid UTF-8.
59 with self.assertRaises(SynapseError):
60 parse_json_value_from_request(make_request(b"\xFF\x00"))
61
62 # Invalid JSON.
63 with self.assertRaises(SynapseError):
64 parse_json_value_from_request(make_request(b"foo"))
65
66 with self.assertRaises(SynapseError):
67 parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
68
69 def test_parse_json_object(self):
70 """Basic tests for parse_json_object_from_request."""
71 # Test empty.
72 result = parse_json_object_from_request(
73 make_request(b""), allow_empty_body=True
74 )
75 self.assertEqual(result, {})
76
77 # Test not an object
78 with self.assertRaises(SynapseError):
79 parse_json_object_from_request(make_request(b'["foo"]'))
2828 from tests.unittest import DEBUG, HomeserverTestCase
2929
3030
31 class FakeBeginner(object):
31 class FakeBeginner:
3232 def beginLoggingTo(self, observers, **kwargs):
3333 self.observers = observers
3434
3535
36 class StructuredLoggingTestBase(object):
36 class StructuredLoggingTestBase:
3737 """
3838 Test base that registers a cleanup handler to reset the stdlib log handler
3939 to 'unset'.
3434 # Check that the new user exists with all provided attributes
3535 self.assertEqual(user_id, "@bob:test")
3636 self.assertTrue(access_token)
37 self.assertTrue(self.store.get_user_by_id(user_id))
37 self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
3838
3939 # Check that the email was assigned
4040 emails = self.get_success(self.store.user_get_threepids(user_id))
2626
2727
2828 @attr.s
29 class _User(object):
29 class _User:
3030 "Helper wrapper for user ID and access token"
3131 id = attr.ib()
3232 token = attr.ib()
169169 last_stream_ordering = pushers[0]["last_stream_ordering"]
170170
171171 # Advance time a bit, so the pusher will register something has happened
172 self.pump(100)
172 self.pump(10)
173173
174174 # It hasn't succeeded yet, so the stream ordering shouldn't have moved
175175 pushers = self.get_success(
159159 self.check(
160160 "get_unread_event_push_actions_by_room_for_user",
161161 [ROOM_ID, USER_ID_2, event1.event_id],
162 {"highlight_count": 0, "notify_count": 0},
162 {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
163163 )
164164
165165 self.persist(
172172 self.check(
173173 "get_unread_event_push_actions_by_room_for_user",
174174 [ROOM_ID, USER_ID_2, event1.event_id],
175 {"highlight_count": 0, "notify_count": 1},
175 {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
176176 )
177177
178178 self.persist(
187187 self.check(
188188 "get_unread_event_push_actions_by_room_for_user",
189189 [ROOM_ID, USER_ID_2, event1.event_id],
190 {"highlight_count": 1, "notify_count": 2},
190 {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
191191 )
192192
193193 def test_get_rooms_for_user_with_stream_ordering(self):
367367
368368 self.get_success(
369369 self.master_store.add_push_actions_to_staging(
370 event.event_id, {user_id: actions for user_id, actions in push_actions}
370 event.event_id,
371 {user_id: actions for user_id, actions in push_actions},
372 False,
371373 )
372374 )
373375 return event, context
1919 from synapse.events.builder import EventBuilderFactory
2020 from synapse.rest.admin import register_servlets_for_client_rest_resource
2121 from synapse.rest.client.v1 import login, room
22 from synapse.types import UserID
22 from synapse.types import UserID, create_requester
2323
2424 from tests.replication._base import BaseMultiWorkerStreamTestCase
2525 from tests.test_utils import make_awaitable
174174 self.get_success(
175175 typing_handler.started_typing(
176176 target_user=UserID.from_string(user),
177 auth_user=UserID.from_string(user),
177 requester=create_requester(user),
178178 room_id=room,
179179 timeout=20000,
180180 )
1919
2020 from mock import Mock
2121
22 from twisted.internet import defer
23
2422 import synapse.rest.admin
2523 from synapse.api.constants import UserTypes
2624 from synapse.api.errors import HttpResponseException, ResourceLimitError
2826 from synapse.rest.client.v2_alpha import sync
2927
3028 from tests import unittest
29 from tests.test_utils import make_awaitable
3130 from tests.unittest import override_config
3231
3332
337336
338337 # Set monthly active users to the limit
339338 store.get_monthly_active_count = Mock(
340 return_value=defer.succeed(self.hs.config.max_mau_value)
339 side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
341340 )
342341 # Check that the blocking of monthly active users is working as expected
343342 # The registration of a new user fails due to the limit
591590
592591 # Set monthly active users to the limit
593592 self.store.get_monthly_active_count = Mock(
594 return_value=defer.succeed(self.hs.config.max_mau_value)
593 side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
595594 )
596595 # Check that the blocking of monthly active users is working as expected
597596 # The registration of a new user fails due to the limit
631630
632631 # Set monthly active users to the limit
633632 self.store.get_monthly_active_count = Mock(
634 return_value=defer.succeed(self.hs.config.max_mau_value)
633 side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
635634 )
636635 # Check that the blocking of monthly active users is working as expected
637636 # The registration of a new user fails due to the limit
4444 }
4545
4646 self.hs = self.setup_test_homeserver(config=config)
47
4748 return self.hs
4849
4950 def prepare(self, reactor, clock, homeserver):
5051 self.user_id = self.register_user("user", "password")
5152 self.token = self.login("user", "password")
5253
53 def test_retention_state_event(self):
54 """Tests that the server configuration can limit the values a user can set to the
55 room's retention policy.
56 """
57 room_id = self.helper.create_room_as(self.user_id, tok=self.token)
58
59 self.helper.send_state(
60 room_id=room_id,
61 event_type=EventTypes.Retention,
62 body={"max_lifetime": one_day_ms * 4},
63 tok=self.token,
64 expect_code=400,
65 )
66
67 self.helper.send_state(
68 room_id=room_id,
69 event_type=EventTypes.Retention,
70 body={"max_lifetime": one_hour_ms},
71 tok=self.token,
72 expect_code=400,
73 )
54 self.store = self.hs.get_datastore()
55 self.serializer = self.hs.get_event_client_serializer()
56 self.clock = self.hs.get_clock()
7457
7558 def test_retention_event_purged_with_state_event(self):
7659 """Tests that expired events are correctly purged when the room's retention policy
8972
9073 self._test_retention_event_purged(room_id, one_day_ms * 1.5)
9174
75 def test_retention_event_purged_with_state_event_outside_allowed(self):
76 """Tests that the server configuration can override the policy for a room when
77 running the purge jobs.
78 """
79 room_id = self.helper.create_room_as(self.user_id, tok=self.token)
80
81 # Set a max_lifetime higher than the maximum allowed value.
82 self.helper.send_state(
83 room_id=room_id,
84 event_type=EventTypes.Retention,
85 body={"max_lifetime": one_day_ms * 4},
86 tok=self.token,
87 )
88
89 # Check that the event is purged after waiting for the maximum allowed duration
90 # instead of the one specified in the room's policy.
91 self._test_retention_event_purged(room_id, one_day_ms * 1.5)
92
93 # Set a max_lifetime lower than the minimum allowed value.
94 self.helper.send_state(
95 room_id=room_id,
96 event_type=EventTypes.Retention,
97 body={"max_lifetime": one_hour_ms},
98 tok=self.token,
99 )
100
101 # Check that the event is purged after waiting for the minimum allowed duration
102 # instead of the one specified in the room's policy.
103 self._test_retention_event_purged(room_id, one_day_ms * 0.5)
104
92105 def test_retention_event_purged_without_state_event(self):
93106 """Tests that expired events are correctly purged when the room's retention policy
94107 is defined by the server's configuration's default retention policy.
139152 # That event should be the second, not outdated event.
140153 self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
141154
142 def _test_retention_event_purged(self, room_id, increment):
155 def _test_retention_event_purged(self, room_id: str, increment: float):
156 """Run the following test scenario to test the message retention policy support:
157
158 1. Send event 1
159 2. Increment time by `increment`
160 3. Send event 2
161 4. Increment time by `increment`
162 5. Check that event 1 has been purged
163 6. Check that event 2 has not been purged
164 7. Check that state events that were sent before event 1 aren't purged.
165 The main reason for sending a second event is that Synapse currently won't
166 purge the latest message in a room, since doing so would leave the room
167 without forward extremities. Sending two events also checks that the purge
168 jobs aren't too greedy and don't purge messages they shouldn't.
169
170 Args:
171 room_id: The ID of the room to test retention in.
172 increment: The number of milliseconds to advance the clock each time. Must be
173 defined so that events in the room aren't purged if they are `increment`
174 old but are purged if they are `increment * 2` old.
175 """
143176 # Get the create event to, later, check that we can still access it.
144177 message_handler = self.hs.get_message_handler()
145178 create_event = self.get_success(
146179 message_handler.get_room_data(
147 self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
180 self.user_id, room_id, EventTypes.Create, state_key=""
148181 )
149182 )
150183
155188 expired_event_id = resp.get("event_id")
156189
157190 # Check that we can retrieve the event.
158 expired_event = self.get_event(room_id, expired_event_id)
191 expired_event = self.get_event(expired_event_id)
159192 self.assertEqual(
160193 expired_event.get("content", {}).get("body"), "1", expired_event
161194 )
173206 # one should still be kept.
174207 self.reactor.advance(increment / 1000)
175208
176 # Check that the event has been purged from the database.
177 self.get_event(room_id, expired_event_id, expected_code=404)
178
179 # Check that the event that hasn't been purged can still be retrieved.
180 valid_event = self.get_event(room_id, valid_event_id)
209 # Check that the first event has been purged from the database, i.e. that we
210 # can't retrieve it anymore, because it has expired.
211 self.get_event(expired_event_id, expect_none=True)
212
213 # Check that the event that hasn't expired can still be retrieved.
214 valid_event = self.get_event(valid_event_id)
181215 self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
182216
183217 # Check that we can still access state events that were sent before the event that
184218 # has been purged.
185219 self.get_event(create_event.event_id)
186220
187 def get_event(self, room_id, event_id, expected_code=200):
188 url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
189
190 request, channel = self.make_request("GET", url, access_token=self.token)
191 self.render(request)
192
193 self.assertEqual(channel.code, expected_code, channel.result)
194
195 return channel.json_body
221 def get_event(self, event_id, expect_none=False):
222 event = self.get_success(self.store.get_event(event_id, allow_none=True))
223
224 if expect_none:
225 self.assertIsNone(event)
226 return {}
227
228 self.assertIsNotNone(event)
229
230 time_now = self.clock.time_msec()
231 serialized = self.get_success(self.serializer.serialize_event(event, time_now))
232
233 return serialized
196234
197235
198236 class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
0 # Copyright 2020 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 from mock import Mock, patch
15
16 import synapse.rest.admin
17 from synapse.api.constants import EventTypes
18 from synapse.rest.client.v1 import directory, login, profile, room
19 from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
20
21 from tests import unittest
22
23
24 class _ShadowBannedBase(unittest.HomeserverTestCase):
25 def prepare(self, reactor, clock, homeserver):
26 # Create two users, one of which is shadow-banned.
27 self.banned_user_id = self.register_user("banned", "test")
28 self.banned_access_token = self.login("banned", "test")
29
30 self.store = self.hs.get_datastore()
31
32 self.get_success(
33 self.store.db_pool.simple_update(
34 table="users",
35 keyvalues={"name": self.banned_user_id},
36 updatevalues={"shadow_banned": True},
37 desc="shadow_ban",
38 )
39 )
40
41 self.other_user_id = self.register_user("otheruser", "pass")
42 self.other_access_token = self.login("otheruser", "pass")
43
44
45 # To avoid the tests timing out, don't add a delay to "annoy the requester".
46 @patch("random.randint", new=lambda a, b: 0)
47 class RoomTestCase(_ShadowBannedBase):
48 servlets = [
49 synapse.rest.admin.register_servlets_for_client_rest_resource,
50 directory.register_servlets,
51 login.register_servlets,
52 room.register_servlets,
53 room_upgrade_rest_servlet.register_servlets,
54 ]
55
56 def test_invite(self):
57 """Invites from shadow-banned users don't actually get sent."""
58
59 # The create works fine.
60 room_id = self.helper.create_room_as(
61 self.banned_user_id, tok=self.banned_access_token
62 )
63
64 # Inviting the user completes successfully.
65 self.helper.invite(
66 room=room_id,
67 src=self.banned_user_id,
68 tok=self.banned_access_token,
69 targ=self.other_user_id,
70 )
71
72 # But the user wasn't actually invited.
73 invited_rooms = self.get_success(
74 self.store.get_invited_rooms_for_local_user(self.other_user_id)
75 )
76 self.assertEqual(invited_rooms, [])
77
78 def test_invite_3pid(self):
79 """Ensure that a 3PID invite does not attempt to contact the identity server."""
80 identity_handler = self.hs.get_handlers().identity_handler
81 identity_handler.lookup_3pid = Mock(
82 side_effect=AssertionError("This should not get called")
83 )
84
85 # The create works fine.
86 room_id = self.helper.create_room_as(
87 self.banned_user_id, tok=self.banned_access_token
88 )
89
90 # Inviting the user completes successfully.
91 request, channel = self.make_request(
92 "POST",
93 "/rooms/%s/invite" % (room_id,),
94 {"id_server": "test", "medium": "email", "address": "test@test.test"},
95 access_token=self.banned_access_token,
96 )
97 self.render(request)
98 self.assertEquals(200, channel.code, channel.result)
99
100 # This should have raised an error earlier, but double check this wasn't called.
101 identity_handler.lookup_3pid.assert_not_called()
102
103 def test_create_room(self):
104 """Invitations during a room creation should be discarded, but the room still gets created."""
105 # The room creation is successful.
106 request, channel = self.make_request(
107 "POST",
108 "/_matrix/client/r0/createRoom",
109 {"visibility": "public", "invite": [self.other_user_id]},
110 access_token=self.banned_access_token,
111 )
112 self.render(request)
113 self.assertEquals(200, channel.code, channel.result)
114 room_id = channel.json_body["room_id"]
115
116 # But the user wasn't actually invited.
117 invited_rooms = self.get_success(
118 self.store.get_invited_rooms_for_local_user(self.other_user_id)
119 )
120 self.assertEqual(invited_rooms, [])
121
122 # Since a real room was created, the other user should be able to join it.
123 self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
124
125 # Both users should be in the room.
126 users = self.get_success(self.store.get_users_in_room(room_id))
127 self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
128
129 def test_message(self):
130 """Messages from shadow-banned users don't actually get sent."""
131
132 room_id = self.helper.create_room_as(
133 self.other_user_id, tok=self.other_access_token
134 )
135
136 # The user should be in the room.
137 self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
138
139 # Sending a message should complete successfully.
140 result = self.helper.send_event(
141 room_id=room_id,
142 type=EventTypes.Message,
143 content={"msgtype": "m.text", "body": "with right label"},
144 tok=self.banned_access_token,
145 )
146 self.assertIn("event_id", result)
147 event_id = result["event_id"]
148
149 latest_events = self.get_success(
150 self.store.get_latest_event_ids_in_room(room_id)
151 )
152 self.assertNotIn(event_id, latest_events)
153
154 def test_upgrade(self):
155 """A room upgrade should fail, but look like it succeeded."""
156
157 # The create works fine.
158 room_id = self.helper.create_room_as(
159 self.banned_user_id, tok=self.banned_access_token
160 )
161
162 request, channel = self.make_request(
163 "POST",
164 "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
165 {"new_version": "6"},
166 access_token=self.banned_access_token,
167 )
168 self.render(request)
169 self.assertEquals(200, channel.code, channel.result)
170 # A new room_id should be returned.
171 self.assertIn("replacement_room", channel.json_body)
172
173 new_room_id = channel.json_body["replacement_room"]
174
175 # It doesn't really matter what API we use here; we just want to assert
176 # that the room doesn't exist.
177 summary = self.get_success(self.store.get_room_summary(new_room_id))
178 # The summary should be empty since the room doesn't exist.
179 self.assertEqual(summary, {})
180
181 def test_typing(self):
182 """Typing notifications should not be propagated into the room."""
183 # The create works fine.
184 room_id = self.helper.create_room_as(
185 self.banned_user_id, tok=self.banned_access_token
186 )
187
188 request, channel = self.make_request(
189 "PUT",
190 "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
191 {"typing": True, "timeout": 30000},
192 access_token=self.banned_access_token,
193 )
194 self.render(request)
195 self.assertEquals(200, channel.code)
196
197 # There should be no typing events.
198 event_source = self.hs.get_event_sources().sources["typing"]
199 self.assertEquals(event_source.get_current_key(), 0)
200
201 # The other user can join and send typing events.
202 self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
203
204 request, channel = self.make_request(
205 "PUT",
206 "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
207 {"typing": True, "timeout": 30000},
208 access_token=self.other_access_token,
209 )
210 self.render(request)
211 self.assertEquals(200, channel.code)
212
213 # These appear in the room.
214 self.assertEquals(event_source.get_current_key(), 1)
215 events = self.get_success(
216 event_source.get_new_events(from_key=0, room_ids=[room_id])
217 )
218 self.assertEquals(
219 events[0],
220 [
221 {
222 "type": "m.typing",
223 "room_id": room_id,
224 "content": {"user_ids": [self.other_user_id]},
225 }
226 ],
227 )
228
229
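# test_typing above leans on the typing stream key being a monotonic counter:
# a dropped notification shows up as "the key did not advance". A hypothetical
# helper capturing that check (an assumption, not code from this change):
def assert_typing_stream_key(self, expected_key: int):
    event_source = self.hs.get_event_sources().sources["typing"]
    # Each accepted typing update bumps the key; shadow-banned ones don't.
    self.assertEquals(event_source.get_current_key(), expected_key)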
230 # To avoid the tests timing out, don't add a delay to "annoy the requester".
231 @patch("random.randint", new=lambda a, b: 0)
232 class ProfileTestCase(_ShadowBannedBase):
233 servlets = [
234 synapse.rest.admin.register_servlets_for_client_rest_resource,
235 login.register_servlets,
236 profile.register_servlets,
237 room.register_servlets,
238 ]
239
240 def test_displayname(self):
241 """Profile changes should succeed, but don't end up in a room."""
242 original_display_name = "banned"
243 new_display_name = "new name"
244
245 # Join a room.
246 room_id = self.helper.create_room_as(
247 self.banned_user_id, tok=self.banned_access_token
248 )
249
250 # The update should succeed.
251 request, channel = self.make_request(
252 "PUT",
253 "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
254 {"displayname": new_display_name},
255 access_token=self.banned_access_token,
256 )
257 self.render(request)
258 self.assertEquals(200, channel.code, channel.result)
259 self.assertEqual(channel.json_body, {})
260
261 # The user's display name should be updated.
262 request, channel = self.make_request(
263 "GET", "/profile/%s/displayname" % (self.banned_user_id,)
264 )
265 self.render(request)
266 self.assertEqual(channel.code, 200, channel.result)
267 self.assertEqual(channel.json_body["displayname"], new_display_name)
268
269 # But the display name in the room should not be.
270 message_handler = self.hs.get_message_handler()
271 event = self.get_success(
272 message_handler.get_room_data(
273 self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
274 )
275 )
276 self.assertEqual(
277 event.content, {"membership": "join", "displayname": original_display_name}
278 )
279
280 def test_room_displayname(self):
281 """Changes to state events for a room should be processed, but not end up in the room."""
282 original_display_name = "banned"
283 new_display_name = "new name"
284
285 # Join a room.
286 room_id = self.helper.create_room_as(
287 self.banned_user_id, tok=self.banned_access_token
288 )
289
290 # The update should succeed.
291 request, channel = self.make_request(
292 "PUT",
293 "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
294 % (room_id, self.banned_user_id),
295 {"membership": "join", "displayname": new_display_name},
296 access_token=self.banned_access_token,
297 )
298 self.render(request)
299 self.assertEquals(200, channel.code, channel.result)
300 self.assertIn("event_id", channel.json_body)
301
302 # The display name in the room should not be changed.
303 message_handler = self.hs.get_message_handler()
304 event = self.get_success(
305 message_handler.get_room_data(
306 self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
307 )
308 )
309 self.assertEqual(
310 event.content, {"membership": "join", "displayname": original_display_name}
311 )
1818 from tests import unittest
1919
2020
21 class ThirdPartyRulesTestModule(object):
21 class ThirdPartyRulesTestModule:
2222 def __init__(self, config):
2323 pass
2424
6161 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
6262 "password": "monkey",
6363 }
64 request_data = json.dumps(params)
65 request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
64 request, channel = self.make_request(b"POST", LOGIN_URL, params)
6665 self.render(request)
6766
6867 if i == 5:
7574 # than 1min.
7675 self.assertTrue(retry_after_ms < 6000)
7776
78 self.reactor.advance(retry_after_ms / 1000.0)
77 self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
7978
8079 params = {
8180 "type": "m.login.password",
8281 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
8382 "password": "monkey",
8483 }
85 request_data = json.dumps(params)
8684 request, channel = self.make_request(b"POST", LOGIN_URL, params)
8785 self.render(request)
8886
110108 "identifier": {"type": "m.id.user", "user": "kermit"},
111109 "password": "monkey",
112110 }
113 request_data = json.dumps(params)
114 request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
111 request, channel = self.make_request(b"POST", LOGIN_URL, params)
115112 self.render(request)
116113
117114 if i == 5:
131128 "identifier": {"type": "m.id.user", "user": "kermit"},
132129 "password": "monkey",
133130 }
134 request_data = json.dumps(params)
135131 request, channel = self.make_request(b"POST", LOGIN_URL, params)
136132 self.render(request)
137133
159155 "identifier": {"type": "m.id.user", "user": "kermit"},
160156 "password": "notamonkey",
161157 }
162 request_data = json.dumps(params)
163 request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
158 request, channel = self.make_request(b"POST", LOGIN_URL, params)
164159 self.render(request)
165160
166161 if i == 5:
173168 # than 1min.
174169 self.assertTrue(retry_after_ms < 6000)
175170
176 self.reactor.advance(retry_after_ms / 1000.0)
171 self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
177172
178173 params = {
179174 "type": "m.login.password",
180175 "identifier": {"type": "m.id.user", "user": "kermit"},
181176 "password": "notamonkey",
182177 }
183 request_data = json.dumps(params)
184178 request, channel = self.make_request(b"POST", LOGIN_URL, params)
185179 self.render(request)
186180
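# In the rate-limit tests above, `retry_after_ms` is read from the 429
# response body, and the fake reactor is advanced a full second past it so
# the retry lands deterministically outside the window. A sketch of that
# recurring step as a hypothetical helper (assuming the fixtures above):
def _advance_past_ratelimit(self, channel):
    self.assertEquals(channel.result["code"], b"429", channel.result)
    retry_after_ms = channel.json_body["retry_after_ms"]
    # The extra second of margin avoids landing exactly on the boundary.
    self.reactor.advance(retry_after_ms / 1000.0 + 1.0)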
683683 ]
684684
685685 @unittest.override_config(
686 {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
686 {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
687687 )
688688 def test_join_local_ratelimit(self):
689689 """Tests that local joins are actually rate-limited."""
690 for i in range(5):
690 for i in range(3):
691691 self.helper.create_room_as(self.user_id)
692692
693693 self.helper.create_room_as(self.user_id, expect_code=429)
694694
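# With {"per_second": 0.5, "burst_count": 3} the burst size is the binding
# limit: three immediate joins pass and the fourth gets 429, since tokens
# refill at only one every two seconds. A rough sketch of the rule the
# config encodes (an assumption, not Synapse's actual Ratelimiter):
def would_allow(actions_so_far: int, elapsed_s: float,
                per_second: float = 0.5, burst_count: int = 3) -> bool:
    # The effective action count decays at `per_second`; a request is
    # allowed while the decayed count stays below the burst size.
    return actions_so_far - elapsed_s * per_second < burst_count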
695695 @unittest.override_config(
696 {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
696 {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
697697 )
698698 def test_join_local_ratelimit_profile_change(self):
699699 """Tests that sending a profile update into all of the user's joined rooms isn't
700700 rate-limited by the rate-limiter on joins."""
701701
702 # Create and join more rooms than the rate-limiting config allows in a second.
702 # Create and join as many rooms as the rate-limiting config allows in a second.
703703 room_ids = [
704704 self.helper.create_room_as(self.user_id),
705705 self.helper.create_room_as(self.user_id),
706706 self.helper.create_room_as(self.user_id),
707707 ]
708 self.reactor.advance(1)
709 room_ids = room_ids + [
710 self.helper.create_room_as(self.user_id),
711 self.helper.create_room_as(self.user_id),
712 self.helper.create_room_as(self.user_id),
713 ]
708        # Allow some time for the rate-limiter to forget about our multi-join.
709 self.reactor.advance(2)
710 # Add one to make sure we're joined to more rooms than the config allows us to
711 # join in a second.
712 room_ids.append(self.helper.create_room_as(self.user_id))
714713
715714 # Create a profile for the user, since it hasn't been done on registration.
716715 store = self.hs.get_datastore()
717 store.create_profile(UserID.from_string(self.user_id).localpart)
716 self.get_success(
717 store.create_profile(UserID.from_string(self.user_id).localpart)
718 )
718719
719720 # Update the display name for the user.
720721 path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
737738 self.assertEquals(channel.json_body["displayname"], "John Doe")
738739
739740 @unittest.override_config(
740 {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
741 {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
741742 )
742743 def test_join_local_ratelimit_idempotent(self):
743744 """Tests that the room join endpoints remain idempotent despite rate-limiting
753754 for path in paths_to_test:
754755 # Make sure we send more requests than the rate-limiting config would allow
755756 # if all of these requests ended up joining the user to a room.
756 for i in range(6):
757 for i in range(4):
757758 request, channel = self.make_request("POST", path % room_id, {})
758759 self.render(request)
759760 self.assertEquals(channel.code, 200)
2929
3030
3131 @attr.s
32 class RestHelper(object):
32 class RestHelper:
3333 """Contains extra helper functions to quickly and clearly perform a given
3434 REST action, which isn't the focus of the test.
3535 """
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 from twisted.internet import defer
1416
1517 from synapse.api.errors import Codes
1618 from synapse.rest.client.v2_alpha import filter
7274 self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
7375
7476 def test_get_filter(self):
75 filter_id = self.filtering.add_user_filter(
76 user_localpart="apple", user_filter=self.EXAMPLE_FILTER
77 filter_id = defer.ensureDeferred(
78 self.filtering.add_user_filter(
79 user_localpart="apple", user_filter=self.EXAMPLE_FILTER
80 )
7781 )
7882 self.reactor.advance(1)
7983 filter_id = filter_id.result
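# Here the coroutine is resolved by hand rather than via get_success():
# wrap it in a Deferred, pump the fake reactor, then read `.result` off the
# fired Deferred. The bare pattern (a sketch, assuming the test fixtures):
from twisted.internet import defer

def _resolve_manually(self, coro):
    d = defer.ensureDeferred(coro)  # drive the coroutine as a Deferred
    self.reactor.advance(1)         # let the fake reactor run it to completion
    return d.result                 # only safe once the Deferred has fired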
159159 else:
160160 self.assertEquals(channel.result["code"], b"200", channel.result)
161161
162 self.reactor.advance(retry_after_ms / 1000.0)
162 self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
163163
164164 request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
165165 self.render(request)
185185 else:
186186 self.assertEquals(channel.result["code"], b"200", channel.result)
187187
188 self.reactor.advance(retry_after_ms / 1000.0)
188 self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
189189
190190 request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
191191 self.render(request)
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 Half-Shot
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import synapse.rest.admin
15 from synapse.rest.client.v1 import login, room
16 from synapse.rest.client.v2_alpha import shared_rooms
17
18 from tests import unittest
19
20
21 class UserSharedRoomsTest(unittest.HomeserverTestCase):
22 """
23 Tests the UserSharedRoomsServlet.
24 """
25
26 servlets = [
27 login.register_servlets,
28 synapse.rest.admin.register_servlets_for_client_rest_resource,
29 room.register_servlets,
30 shared_rooms.register_servlets,
31 ]
32
33 def make_homeserver(self, reactor, clock):
34 config = self.default_config()
35 config["update_user_directory"] = True
36 return self.setup_test_homeserver(config=config)
37
38 def prepare(self, reactor, clock, hs):
39 self.store = hs.get_datastore()
40 self.handler = hs.get_user_directory_handler()
41
42 def _get_shared_rooms(self, token, other_user):
43 request, channel = self.make_request(
44 "GET",
45 "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
46 % other_user,
47 access_token=token,
48 )
49 self.render(request)
50 return request, channel
51
52 def test_shared_room_list_public(self):
53 """
54 A room should show up in the shared list of rooms between two users
55 if it is public.
56 """
57 u1 = self.register_user("user1", "pass")
58 u1_token = self.login(u1, "pass")
59 u2 = self.register_user("user2", "pass")
60 u2_token = self.login(u2, "pass")
61
62 room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
63 self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
64 self.helper.join(room, user=u2, tok=u2_token)
65
66 request, channel = self._get_shared_rooms(u1_token, u2)
67 self.assertEquals(200, channel.code, channel.result)
68 self.assertEquals(len(channel.json_body["joined"]), 1)
69 self.assertEquals(channel.json_body["joined"][0], room)
70
71 def test_shared_room_list_private(self):
72 """
73 A room should show up in the shared list of rooms between two users
74 if it is private.
75 """
76 u1 = self.register_user("user1", "pass")
77 u1_token = self.login(u1, "pass")
78 u2 = self.register_user("user2", "pass")
79 u2_token = self.login(u2, "pass")
80
81 room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
82 self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
83 self.helper.join(room, user=u2, tok=u2_token)
84
85 request, channel = self._get_shared_rooms(u1_token, u2)
86 self.assertEquals(200, channel.code, channel.result)
87 self.assertEquals(len(channel.json_body["joined"]), 1)
88 self.assertEquals(channel.json_body["joined"][0], room)
89
90 def test_shared_room_list_mixed(self):
91 """
92 The shared room list between two users should contain both public and private
93 rooms.
94 """
95 u1 = self.register_user("user1", "pass")
96 u1_token = self.login(u1, "pass")
97 u2 = self.register_user("user2", "pass")
98 u2_token = self.login(u2, "pass")
99
100 room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
101 room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
102 self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
103 self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
104 self.helper.join(room_public, user=u2, tok=u2_token)
105 self.helper.join(room_private, user=u1, tok=u1_token)
106
107 request, channel = self._get_shared_rooms(u1_token, u2)
108 self.assertEquals(200, channel.code, channel.result)
109 self.assertEquals(len(channel.json_body["joined"]), 2)
110        self.assertIn(room_public, channel.json_body["joined"])
111        self.assertIn(room_private, channel.json_body["joined"])
112
113 def test_shared_room_list_after_leave(self):
114 """
115 A room should no longer be considered shared if the other
116 user has left it.
117 """
118 u1 = self.register_user("user1", "pass")
119 u1_token = self.login(u1, "pass")
120 u2 = self.register_user("user2", "pass")
121 u2_token = self.login(u2, "pass")
122
123 room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
124 self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
125 self.helper.join(room, user=u2, tok=u2_token)
126
127        # Assert the room shows up as shared before either user leaves.
128 request, channel = self._get_shared_rooms(u1_token, u2)
129 self.assertEquals(200, channel.code, channel.result)
130 self.assertEquals(len(channel.json_body["joined"]), 1)
131 self.assertEquals(channel.json_body["joined"][0], room)
132
133 self.helper.leave(room, user=u1, tok=u1_token)
134
135 request, channel = self._get_shared_rooms(u2_token, u1)
136 self.assertEquals(200, channel.code, channel.result)
137 self.assertEquals(len(channel.json_body["joined"]), 0)
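# The unstable MSC2666 endpoint exercised above can also be queried
# directly; a hedged client-side sketch (the path is taken from the test,
# the host and token are illustrative):
import requests

resp = requests.get(
    "https://homeserver.example/_matrix/client/unstable/"
    "uk.half-shot.msc2666/user/shared_rooms/@other:example",
    headers={"Authorization": "Bearer <access_token>"},
)
print(resp.json()["joined"])  # room IDs shared with @other:example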
1515 import json
1616
1717 import synapse.rest.admin
18 from synapse.api.constants import EventContentFields, EventTypes
18 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
1919 from synapse.rest.client.v1 import login, room
20 from synapse.rest.client.v2_alpha import sync
20 from synapse.rest.client.v2_alpha import read_marker, sync
2121
2222 from tests import unittest
2323 from tests.server import TimedOutException
323323 "GET", sync_url % (access_token, next_batch)
324324 )
325325 self.assertRaises(TimedOutException, self.render, request)
326
327
328 class UnreadMessagesTestCase(unittest.HomeserverTestCase):
329 servlets = [
330 synapse.rest.admin.register_servlets,
331 login.register_servlets,
332 read_marker.register_servlets,
333 room.register_servlets,
334 sync.register_servlets,
335 ]
336
337 def prepare(self, reactor, clock, hs):
338 self.url = "/sync?since=%s"
339 self.next_batch = "s0"
340
341 # Register the first user (used to check the unread counts).
342 self.user_id = self.register_user("kermit", "monkey")
343 self.tok = self.login("kermit", "monkey")
344
345 # Create the room we'll check unread counts for.
346 self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
347
348 # Register the second user (used to send events to the room).
349 self.user2 = self.register_user("kermit2", "monkey")
350 self.tok2 = self.login("kermit2", "monkey")
351
352 # Change the power levels of the room so that the second user can send state
353 # events.
354 self.helper.send_state(
355 self.room_id,
356 EventTypes.PowerLevels,
357 {
358 "users": {self.user_id: 100, self.user2: 100},
359 "users_default": 0,
360 "events": {
361 "m.room.name": 50,
362 "m.room.power_levels": 100,
363 "m.room.history_visibility": 100,
364 "m.room.canonical_alias": 50,
365 "m.room.avatar": 50,
366 "m.room.tombstone": 100,
367 "m.room.server_acl": 100,
368 "m.room.encryption": 100,
369 },
370 "events_default": 0,
371 "state_default": 50,
372 "ban": 50,
373 "kick": 50,
374 "redact": 50,
375 "invite": 0,
376 },
377 tok=self.tok,
378 )
379
380 def test_unread_counts(self):
381 """Tests that /sync returns the right value for the unread count (MSC2654)."""
382
383 # Check that our own messages don't increase the unread count.
384 self.helper.send(self.room_id, "hello", tok=self.tok)
385 self._check_unread_count(0)
386
387 # Join the new user and check that this doesn't increase the unread count.
388 self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
389 self._check_unread_count(0)
390
391 # Check that the new user sending a message increases our unread count.
392 res = self.helper.send(self.room_id, "hello", tok=self.tok2)
393 self._check_unread_count(1)
394
395 # Send a read receipt to tell the server we've read the latest event.
396 body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
397 request, channel = self.make_request(
398 "POST",
399 "/rooms/%s/read_markers" % self.room_id,
400 body,
401 access_token=self.tok,
402 )
403 self.render(request)
404 self.assertEqual(channel.code, 200, channel.json_body)
405
406 # Check that the unread counter is back to 0.
407 self._check_unread_count(0)
408
409 # Check that room name changes increase the unread counter.
410 self.helper.send_state(
411 self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
412 )
413 self._check_unread_count(1)
414
415 # Check that room topic changes increase the unread counter.
416 self.helper.send_state(
417 self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
418 )
419 self._check_unread_count(2)
420
421 # Check that encrypted messages increase the unread counter.
422 self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
423 self._check_unread_count(3)
424
425 # Check that custom events with a body increase the unread counter.
426 self.helper.send_event(
427 self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
428 )
429 self._check_unread_count(4)
430
431 # Check that edits don't increase the unread counter.
432 self.helper.send_event(
433 room_id=self.room_id,
434 type=EventTypes.Message,
435 content={
436 "body": "hello",
437 "msgtype": "m.text",
438 "m.relates_to": {"rel_type": RelationTypes.REPLACE},
439 },
440 tok=self.tok2,
441 )
442 self._check_unread_count(4)
443
444 # Check that notices don't increase the unread counter.
445 self.helper.send_event(
446 room_id=self.room_id,
447 type=EventTypes.Message,
448 content={"body": "hello", "msgtype": "m.notice"},
449 tok=self.tok2,
450 )
451 self._check_unread_count(4)
452
453        # Check that tombstone events increase the unread counter.
454 self.helper.send_state(
455 self.room_id,
456 EventTypes.Tombstone,
457 {"replacement_room": "!someroom:test"},
458 tok=self.tok2,
459 )
460 self._check_unread_count(5)
461
462    def _check_unread_count(self, expected_count: int):
463 """Syncs and compares the unread count with the expected value."""
464
465 request, channel = self.make_request(
466 "GET", self.url % self.next_batch, access_token=self.tok,
467 )
468 self.render(request)
469
470 self.assertEqual(channel.code, 200, channel.json_body)
471
472 room_entry = channel.json_body["rooms"]["join"][self.room_id]
473 self.assertEqual(
474 room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
475 )
476
477 # Store the next batch for the next request.
478 self.next_batch = channel.json_body["next_batch"]
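# Summarising what test_unread_counts asserts about MSC2654: the counter
# increases for other users' messages, room name/topic changes, encrypted
# events, custom events with a body, and tombstones; it does not increase
# for the user's own messages, joins, m.replace edits, or m.notice messages.
# Reading it out of a sync response (mirroring the helper above):
def _read_unread_count(self, channel, room_id):
    room_entry = channel.json_body["rooms"]["join"][room_id]
    return room_entry["org.matrix.msc2654.unread_count"]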
3131
3232
3333 @attr.s
34 class FakeResponse(object):
34 class FakeResponse:
3535 version = attr.ib()
3636 code = attr.ib()
3737 phrase = attr.ib()
4242 @property
4343 def request(self):
4444 @attr.s
45 class FakeTransport(object):
45 class FakeTransport:
4646 absoluteURI = self.absoluteURI
4747
4848 return FakeTransport()
110110
111111 self.lookups = {}
112112
113 class Resolver(object):
113 class Resolver:
114114 def resolveHostName(
115115 _self,
116116 resolutionReceiver,
3434
3535
3636 @attr.s
37 class FakeChannel(object):
37 class FakeChannel:
3838 """
3939 A fake Twisted Web Channel (the part that interfaces with the
4040 wire).
241241 lookups = self.lookups = {}
242242
243243 @implementer(IResolverSimple)
244 class FakeResolver(object):
244 class FakeResolver:
245245 def getHostByName(self, name, timeout=None):
246246 if name not in lookups:
247247 return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
370370
371371
372372 @attr.s(cmp=False)
373 class FakeTransport(object):
373 class FakeTransport:
374374 """
375375 A twisted.internet.interfaces.ITransport implementation which sends all its data
376376 straight into an IProtocol object: it exists to connect two IProtocols together.
6666 raise Exception("Failed to find reference to ResourceLimitsServerNotices")
6767
6868 self._rlsn._store.user_last_seen_monthly_active = Mock(
69 return_value=defer.succeed(1000)
69 side_effect=lambda user_id: make_awaitable(1000)
7070 )
7171 self._rlsn._server_notices_manager.send_notice = Mock(
7272 return_value=defer.succeed(Mock())
103103 type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
104104 )
105105 self._rlsn._store.get_events = Mock(
106 return_value=defer.succeed({"123": mock_event})
106 return_value=make_awaitable({"123": mock_event})
107107 )
108108 self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
109109 # Would be better to check the content, but once == remove blocking event
121121 type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
122122 )
123123 self._rlsn._store.get_events = Mock(
124 return_value=defer.succeed({"123": mock_event})
124 return_value=make_awaitable({"123": mock_event})
125125 )
126126
127127 self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
157157 """
158158 self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
159159 self._rlsn._store.user_last_seen_monthly_active = Mock(
160 return_value=defer.succeed(None)
160 side_effect=lambda user_id: make_awaitable(None)
161161 )
162162 self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
163163
216216 type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
217217 )
218218 self._rlsn._store.get_events = Mock(
219 return_value=defer.succeed({"123": mock_event})
219 return_value=make_awaitable({"123": mock_event})
220220 )
221221 self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
222222
260260 self.user_id = "@user_id:test"
261261
262262 def test_server_notice_only_sent_once(self):
263 self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
263 self.store.get_monthly_active_count = Mock(
264 side_effect=lambda: make_awaitable(1000)
265 )
264266
265267 self.store.user_last_seen_monthly_active = Mock(
266 return_value=defer.succeed(1000)
268 side_effect=lambda user_id: make_awaitable(1000)
267269 )
268270
269271 # Call the function multiple times to ensure we only send the notice once
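# These hunks replace `defer.succeed(...)` mocks with `make_awaitable(...)`
# because the mocked storage methods are now awaited as coroutines, and
# `side_effect=lambda ...` mints a fresh awaitable per call for mocks that
# are awaited more than once. A minimal stand-in for the helper (the real
# one lives in tests.test_utils and may differ, e.g. by returning a
# re-awaitable Future):
def make_awaitable(result):
    async def _inner():
        return result
    return _inner()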
4848 return defer.succeed(None)
4949
5050
51 class FakeEvent(object):
51 class FakeEvent:
5252 """A fake event we use as a convenience.
5353
5454 NOTE: Again as a convenience we use "node_ids" rather than event_ids to
594594
595595
596596 @attr.s
597 class TestStateResolutionStore(object):
597 class TestStateResolutionStore:
598598 event_map = attr.ib()
599599
600600 def get_events(self, event_ids, allow_rejected=False):
9898 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
9999 @defer.inlineCallbacks
100100 def test_passthrough(self):
101 class A(object):
101 class A:
102102 @cached()
103103 def func(self, key):
104104 return key
112112 def test_hit(self):
113113 callcount = [0]
114114
115 class A(object):
115 class A:
116116 @cached()
117117 def func(self, key):
118118 callcount[0] += 1
130130 def test_invalidate(self):
131131 callcount = [0]
132132
133 class A(object):
133 class A:
134134 @cached()
135135 def func(self, key):
136136 callcount[0] += 1
148148 self.assertEquals(callcount[0], 2)
149149
150150 def test_invalidate_missing(self):
151 class A(object):
151 class A:
152152 @cached()
153153 def func(self, key):
154154 return key
159159 def test_max_entries(self):
160160 callcount = [0]
161161
162 class A(object):
162 class A:
163163 @cached(max_entries=10)
164164 def func(self, key):
165165 callcount[0] += 1
186186
187187 d = defer.succeed(123)
188188
189 class A(object):
189 class A:
190190 @cached()
191191 def func(self, key):
192192 callcount[0] += 1
204204 callcount = [0]
205205 callcount2 = [0]
206206
207 class A(object):
207 class A:
208208 @cached()
209209 def func(self, key):
210210 callcount[0] += 1
237237 callcount = [0]
238238 callcount2 = [0]
239239
240 class A(object):
240 class A:
241241 @cached(max_entries=2)
242242 def func(self, key):
243243 callcount[0] += 1
274274 callcount = [0]
275275 callcount2 = [0]
276276
277 class A(object):
277 class A:
278278 @cached()
279279 def func(self, key):
280280 callcount[0] += 1
3030 )
3131
3232 from tests import unittest
33 from tests.test_utils import make_awaitable
3334 from tests.utils import setup_test_homeserver
3435
3536
206207 @defer.inlineCallbacks
207208 def test_set_appservices_state_down(self):
208209 service = Mock(id=self.as_list[1]["id"])
209 yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
210 yield defer.ensureDeferred(
211 self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
212 )
210213 rows = yield self.db_pool.runQuery(
211214 self.engine.convert_param_style(
212215 "SELECT as_id FROM application_services_state WHERE state=?"
218221 @defer.inlineCallbacks
219222 def test_set_appservices_state_multiple_up(self):
220223 service = Mock(id=self.as_list[1]["id"])
221 yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
222 yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
223 yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
224 yield defer.ensureDeferred(
225 self.store.set_appservice_state(service, ApplicationServiceState.UP)
226 )
227 yield defer.ensureDeferred(
228 self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
229 )
230 yield defer.ensureDeferred(
231 self.store.set_appservice_state(service, ApplicationServiceState.UP)
232 )
224233 rows = yield self.db_pool.runQuery(
225234 self.engine.convert_param_style(
226235 "SELECT as_id FROM application_services_state WHERE state=?"
233242 def test_create_appservice_txn_first(self):
234243 service = Mock(id=self.as_list[0]["id"])
235244 events = [Mock(event_id="e1"), Mock(event_id="e2")]
236 txn = yield self.store.create_appservice_txn(service, events)
245 txn = yield defer.ensureDeferred(
246 self.store.create_appservice_txn(service, events)
247 )
237248 self.assertEquals(txn.id, 1)
238249 self.assertEquals(txn.events, events)
239250 self.assertEquals(txn.service, service)
245256 yield self._set_last_txn(service.id, 9643) # AS is falling behind
246257 yield self._insert_txn(service.id, 9644, events)
247258 yield self._insert_txn(service.id, 9645, events)
248 txn = yield self.store.create_appservice_txn(service, events)
259 txn = yield defer.ensureDeferred(
260 self.store.create_appservice_txn(service, events)
261 )
249262 self.assertEquals(txn.id, 9646)
250263 self.assertEquals(txn.events, events)
251264 self.assertEquals(txn.service, service)
255268 service = Mock(id=self.as_list[0]["id"])
256269 events = [Mock(event_id="e1"), Mock(event_id="e2")]
257270 yield self._set_last_txn(service.id, 9643)
258 txn = yield self.store.create_appservice_txn(service, events)
271 txn = yield defer.ensureDeferred(
272 self.store.create_appservice_txn(service, events)
273 )
259274 self.assertEquals(txn.id, 9644)
260275 self.assertEquals(txn.events, events)
261276 self.assertEquals(txn.service, service)
276291 yield self._insert_txn(self.as_list[2]["id"], 10, events)
277292 yield self._insert_txn(self.as_list[3]["id"], 9643, events)
278293
279 txn = yield self.store.create_appservice_txn(service, events)
294 txn = yield defer.ensureDeferred(
295 self.store.create_appservice_txn(service, events)
296 )
280297 self.assertEquals(txn.id, 9644)
281298 self.assertEquals(txn.events, events)
282299 self.assertEquals(txn.service, service)
288305 txn_id = 1
289306
290307 yield self._insert_txn(service.id, txn_id, events)
291 yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
308 yield defer.ensureDeferred(
309 self.store.complete_appservice_txn(txn_id=txn_id, service=service)
310 )
292311
293312 res = yield self.db_pool.runQuery(
294313 self.engine.convert_param_style(
314333 txn_id = 5
315334 yield self._set_last_txn(service.id, 4)
316335 yield self._insert_txn(service.id, txn_id, events)
317 yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
336 yield defer.ensureDeferred(
337 self.store.complete_appservice_txn(txn_id=txn_id, service=service)
338 )
318339
319340 res = yield self.db_pool.runQuery(
320341 self.engine.convert_param_style(
348369 other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
349370
350371 # we aren't testing store._base stuff here, so mock this out
351 self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
372 self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
352373
353374 yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
354375 yield self._insert_txn(service.id, 10, events)
00 from mock import Mock
1
2 from twisted.internet import defer
31
42 from synapse.storage.background_updates import BackgroundUpdater
53
3735 )
3836
3937 # first step: make a bit of progress
40 @defer.inlineCallbacks
41 def update(progress, count):
42 yield self.clock.sleep((count * duration_ms) / 1000)
38 async def update(progress, count):
39 await self.clock.sleep((count * duration_ms) / 1000)
4340 progress = {"my_key": progress["my_key"] + 1}
44 yield store.db_pool.runInteraction(
41 await store.db_pool.runInteraction(
4542 "update_progress",
4643 self.updates._background_update_progress_txn,
4744 "test_update",
6663
6764 # second step: complete the update
6865 # we should now get run with a much bigger number of items to update
69 @defer.inlineCallbacks
70 def update(progress, count):
66 async def update(progress, count):
7167 self.assertEqual(progress, {"my_key": 2})
7268 self.assertAlmostEqual(
7369 count, target_background_update_duration_ms / duration_ms, places=0,
7470 )
75 yield self.updates._end_background_update("test_update")
71 await self.updates._end_background_update("test_update")
7672 return count
7773
7874 self.update_handler.side_effect = update
6565 def test_insert_1col(self):
6666 self.mock_txn.rowcount = 1
6767
68 yield self.datastore.db_pool.simple_insert(
69 table="tablename", values={"columname": "Value"}
68 yield defer.ensureDeferred(
69 self.datastore.db_pool.simple_insert(
70 table="tablename", values={"columname": "Value"}
71 )
7072 )
7173
7274 self.mock_txn.execute.assert_called_with(
7779 def test_insert_3cols(self):
7880 self.mock_txn.rowcount = 1
7981
80 yield self.datastore.db_pool.simple_insert(
81 table="tablename",
82 # Use OrderedDict() so we can assert on the SQL generated
83 values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
82 yield defer.ensureDeferred(
83 self.datastore.db_pool.simple_insert(
84 table="tablename",
85 # Use OrderedDict() so we can assert on the SQL generated
86 values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
87 )
8488 )
8589
8690 self.mock_txn.execute.assert_called_with(
9296 self.mock_txn.rowcount = 1
9397 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
9498
95 value = yield self.datastore.db_pool.simple_select_one_onecol(
96 table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
99 value = yield defer.ensureDeferred(
100 self.datastore.db_pool.simple_select_one_onecol(
101 table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
102 )
97103 )
98104
99105 self.assertEquals("Value", value)
106112 self.mock_txn.rowcount = 1
107113 self.mock_txn.fetchone.return_value = (1, 2, 3)
108114
109 ret = yield self.datastore.db_pool.simple_select_one(
110 table="tablename",
111 keyvalues={"keycol": "TheKey"},
112 retcols=["colA", "colB", "colC"],
115 ret = yield defer.ensureDeferred(
116 self.datastore.db_pool.simple_select_one(
117 table="tablename",
118 keyvalues={"keycol": "TheKey"},
119 retcols=["colA", "colB", "colC"],
120 )
113121 )
114122
115123 self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
122130 self.mock_txn.rowcount = 0
123131 self.mock_txn.fetchone.return_value = None
124132
125 ret = yield self.datastore.db_pool.simple_select_one(
126 table="tablename",
127 keyvalues={"keycol": "Not here"},
128 retcols=["colA"],
129 allow_none=True,
133 ret = yield defer.ensureDeferred(
134 self.datastore.db_pool.simple_select_one(
135 table="tablename",
136 keyvalues={"keycol": "Not here"},
137 retcols=["colA"],
138 allow_none=True,
139 )
130140 )
131141
132142 self.assertFalse(ret)
137147 self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
138148 self.mock_txn.description = (("colA", None, None, None, None, None, None),)
139149
140 ret = yield self.datastore.db_pool.simple_select_list(
141 table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
150 ret = yield defer.ensureDeferred(
151 self.datastore.db_pool.simple_select_list(
152 table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
153 )
142154 )
143155
144156 self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
150162 def test_update_one_1col(self):
151163 self.mock_txn.rowcount = 1
152164
153 yield self.datastore.db_pool.simple_update_one(
154 table="tablename",
155 keyvalues={"keycol": "TheKey"},
156 updatevalues={"columnname": "New Value"},
165 yield defer.ensureDeferred(
166 self.datastore.db_pool.simple_update_one(
167 table="tablename",
168 keyvalues={"keycol": "TheKey"},
169 updatevalues={"columnname": "New Value"},
170 )
157171 )
158172
159173 self.mock_txn.execute.assert_called_with(
165179 def test_update_one_4cols(self):
166180 self.mock_txn.rowcount = 1
167181
168 yield self.datastore.db_pool.simple_update_one(
169 table="tablename",
170 keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
171 updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
182 yield defer.ensureDeferred(
183 self.datastore.db_pool.simple_update_one(
184 table="tablename",
185 keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
186 updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
187 )
172188 )
173189
174190 self.mock_txn.execute.assert_called_with(
180196 def test_delete_one(self):
181197 self.mock_txn.rowcount = 1
182198
183 yield self.datastore.db_pool.simple_delete_one(
184 table="tablename", keyvalues={"keycol": "Go away"}
199 yield defer.ensureDeferred(
200 self.datastore.db_pool.simple_delete_one(
201 table="tablename", keyvalues={"keycol": "Go away"}
202 )
185203 )
186204
187205 self.mock_txn.execute.assert_called_with(
3737
3838 # Create a test user and room
3939 self.user = UserID("alice", "test")
40 self.requester = Requester(self.user, None, False, None, None)
40 self.requester = Requester(self.user, None, False, False, None, None)
4141 info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
4242 self.room_id = info["room_id"]
4343
259259 # Create a test user and room
260260 self.user = UserID.from_string(self.register_user("user1", "password"))
261261 self.token1 = self.login("user1", "password")
262 self.requester = Requester(self.user, None, False, None, None)
262 self.requester = Requester(self.user, None, False, False, None, None)
263263 info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
264264 self.room_id = info["room_id"]
265265 self.event_creator = homeserver.get_event_creation_handler()
270270
271271 # Pump the reactor repeatedly so that the background updates have a
272272 # chance to run.
273 self.pump(10 * 60)
273 self.pump(20)
274274
275275 latest_event_ids = self.get_success(
276276 self.store.get_latest_event_ids_in_room(self.room_id)
352352 self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
353353 "3"
354354 ] = 300000
355
355356 self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
356357 # All entries within time frame
357358 self.assertEqual(
361362 3,
362363 )
363364 # Oldest room to expire
364 self.pump(1)
365 self.pump(1.01)
365366 self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
366367 self.assertEqual(
367368 len(
1515
1616 from mock import Mock
1717
18 from twisted.internet import defer
19
2018 import synapse.rest.admin
2119 from synapse.http.site import XForwardedForRequest
2220 from synapse.rest.client.v1 import login
2321
2422 from tests import unittest
23 from tests.test_utils import make_awaitable
2524 from tests.unittest import override_config
2625
2726
154153 user_id = "@user:server"
155154
156155 self.store.get_monthly_active_count = Mock(
157 return_value=defer.succeed(lots_of_users)
156 side_effect=lambda: make_awaitable(lots_of_users)
158157 )
159158 self.get_success(
160159 self.store.insert_client_ip(
3737 self.store.store_device("user_id", "device_id", "display_name")
3838 )
3939
40 res = yield self.store.get_device("user_id", "device_id")
40 res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
4141 self.assertDictContainsSubset(
4242 {
4343 "user_id": "user_id",
110110 self.store.store_device("user_id", "device_id", "display_name 1")
111111 )
112112
113 res = yield self.store.get_device("user_id", "device_id")
113 res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
114114 self.assertEqual("display_name 1", res["display_name"])
115115
116116 # do a no-op first
117117 yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
118 res = yield self.store.get_device("user_id", "device_id")
118 res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
119119 self.assertEqual("display_name 1", res["display_name"])
120120
121121 # do the update
126126 )
127127
128128 # check it worked
129 res = yield self.store.get_device("user_id", "device_id")
129 res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
130130 self.assertEqual("display_name 2", res["display_name"])
131131
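# The recurring `defer.ensureDeferred(...)` wrappers in this file exist
# because the storage methods are now `async def`, while `yield` inside
# `@defer.inlineCallbacks` only understands Deferreds. The pattern in
# isolation (a sketch, with `store` standing in for the fixture):
from twisted.internet import defer

@defer.inlineCallbacks
def _caller(store):
    # Wrap the coroutine in a Deferred so inlineCallbacks can drive it.
    res = yield defer.ensureDeferred(store.get_device("user_id", "device_id"))
    return res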
132132 @defer.inlineCallbacks
4141
4242 self.assertEquals(
4343 ["#my-room:test"],
44 (yield self.store.get_aliases_for_room(self.room.to_string())),
44 (
45 yield defer.ensureDeferred(
46 self.store.get_aliases_for_room(self.room.to_string())
47 )
48 ),
4549 )
4650
4751 @defer.inlineCallbacks
3131
3232 yield defer.ensureDeferred(self.store.store_device("user", "device", None))
3333
34 yield self.store.set_e2e_device_keys("user", "device", now, json)
34 yield defer.ensureDeferred(
35 self.store.set_e2e_device_keys("user", "device", now, json)
36 )
3537
3638 res = yield defer.ensureDeferred(
37 self.store.get_e2e_device_keys((("user", "device"),))
39 self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
3840 )
3941 self.assertIn("user", res)
4042 self.assertIn("device", res["user"])
4850
4951 yield defer.ensureDeferred(self.store.store_device("user", "device", None))
5052
51 changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
53 changed = yield defer.ensureDeferred(
54 self.store.set_e2e_device_keys("user", "device", now, json)
55 )
5256 self.assertTrue(changed)
5357
5458 # If we try to upload the same key then we should be told nothing
5559 # changed
56 changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
60 changed = yield defer.ensureDeferred(
61 self.store.set_e2e_device_keys("user", "device", now, json)
62 )
5763 self.assertFalse(changed)
5864
5965 @defer.inlineCallbacks
6167 now = 1470174257070
6268 json = {"key": "value"}
6369
64 yield self.store.set_e2e_device_keys("user", "device", now, json)
70 yield defer.ensureDeferred(
71 self.store.set_e2e_device_keys("user", "device", now, json)
72 )
6573 yield defer.ensureDeferred(
6674 self.store.store_device("user", "device", "display_name")
6775 )
6876
6977 res = yield defer.ensureDeferred(
70 self.store.get_e2e_device_keys((("user", "device"),))
78 self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
7179 )
7280 self.assertIn("user", res)
7381 self.assertIn("device", res["user"])
8593 yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
8694 yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
8795
88 yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
89 yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
90 yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
91 yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
96 yield defer.ensureDeferred(
97 self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
98 )
99 yield defer.ensureDeferred(
100 self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
101 )
102 yield defer.ensureDeferred(
103 self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
104 )
105 yield defer.ensureDeferred(
106 self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
107 )
92108
93109 res = yield defer.ensureDeferred(
94 self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
110 self.store.get_e2e_device_keys_for_cs_api(
111 (("user1", "device1"), ("user2", "device2"))
112 )
95113 )
96114 self.assertIn("user1", res)
97115 self.assertIn("device1", res["user1"])
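# This file's hunks rename `get_e2e_device_keys` to
# `get_e2e_device_keys_for_cs_api`; judging by the new name, it is the
# variant whose results are shaped for the client-server API. The call shape
# itself is unchanged (a sketch, reusing the fixtures above):
from twisted.internet import defer

@defer.inlineCallbacks
def _fetch_keys(self):
    res = yield defer.ensureDeferred(
        self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
    )
    return res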
2626 room_creator = self.hs.get_room_creation_handler()
2727
2828 user = UserID("alice", "test")
29 requester = Requester(user, None, False, None, None)
29 requester = Requester(user, None, False, False, None, None)
3030
3131 # Real events, forward extremities
3232 events = [(3, 2), (6, 2), (4, 6)]
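# The extra `False` threaded through these Requester(...) calls is, judging
# by the shadow-ban tests earlier in this diff, a new shadow_banned flag.
# A hedged sketch of the positional order (an inference, not taken from
# this change):
from synapse.types import Requester, UserID

requester = Requester(
    UserID("alice", "test"),  # user
    None,                     # access_token_id
    False,                    # is_guest
    False,                    # shadow_banned (the new field)
    None,                     # device_id
    None,                     # app_service
)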
5959
6060 @defer.inlineCallbacks
6161        def _assert_counts(notif_count, highlight_count):
62 counts = yield self.store.db_pool.runInteraction(
63 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
62 counts = yield defer.ensureDeferred(
63 self.store.db_pool.runInteraction(
64 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
65 )
6466 )
6567 self.assertEquals(
6668 counts,
67                {"notify_count": notif_count, "highlight_count": highlight_count},
69 {
70                    "notify_count": notif_count,
71 "unread_count": 0, # Unread counts are tested in the sync tests.
72 "highlight_count": highlight_count,
73 },
6874 )
6975
7076 @defer.inlineCallbacks
7783
7884 yield defer.ensureDeferred(
7985 self.store.add_push_actions_to_staging(
80 event.event_id, {user_id: action}
81 )
82 )
83 yield self.store.db_pool.runInteraction(
84 "",
85 self.persist_events_store._set_push_actions_for_event_and_users_txn,
86 [(event, None)],
87 [(event, None)],
86 event.event_id, {user_id: action}, False,
87 )
88 )
89 yield defer.ensureDeferred(
90 self.store.db_pool.runInteraction(
91 "",
92 self.persist_events_store._set_push_actions_for_event_and_users_txn,
93 [(event, None)],
94 [(event, None)],
95 )
8896 )
8997
9098 def _rotate(stream):
91 return self.store.db_pool.runInteraction(
92 "", self.store._rotate_notifs_before_txn, stream
99 return defer.ensureDeferred(
100 self.store.db_pool.runInteraction(
101 "", self.store._rotate_notifs_before_txn, stream
102 )
93103 )
94104
95105 def _mark_read(stream, depth):
96 return self.store.db_pool.runInteraction(
97 "",
98 self.store._remove_old_push_actions_before_txn,
99 room_id,
100 user_id,
101 stream,
106 return defer.ensureDeferred(
107 self.store.db_pool.runInteraction(
108 "",
109 self.store._remove_old_push_actions_before_txn,
110 room_id,
111 user_id,
112 stream,
113 )
102114 )
103115
104116 yield _assert_counts(0, 0)
122134 yield _inject_actions(6, PlAIN_NOTIF)
123135 yield _rotate(7)
124136
125 yield self.store.db_pool.simple_delete(
126 table="event_push_actions", keyvalues={"1": 1}, desc=""
137 yield defer.ensureDeferred(
138 self.store.db_pool.simple_delete(
139 table="event_push_actions", keyvalues={"1": 1}, desc=""
140 )
127141 )
128142
129143 yield _assert_counts(1, 0)
141155 @defer.inlineCallbacks
142156 def test_find_first_stream_ordering_after_ts(self):
143157 def add_event(so, ts):
144 return self.store.db_pool.simple_insert(
145 "events",
146 {
147 "stream_ordering": so,
148 "received_ts": ts,
149 "event_id": "event%i" % so,
150 "type": "",
151 "room_id": "",
152 "content": "",
153 "processed": True,
154 "outlier": False,
155 "topological_ordering": 0,
156 "depth": 0,
157 },
158 return defer.ensureDeferred(
159 self.store.db_pool.simple_insert(
160 "events",
161 {
162 "stream_ordering": so,
163 "received_ts": ts,
164 "event_id": "event%i" % so,
165 "type": "",
166 "room_id": "",
167 "content": "",
168 "processed": True,
169 "outlier": False,
170 "topological_ordering": 0,
171 "depth": 0,
172 },
173 )
158174 )
159175
160176 # start with the base case where there are no events in the table
161 r = yield self.store.find_first_stream_ordering_after_ts(11)
177 r = yield defer.ensureDeferred(
178 self.store.find_first_stream_ordering_after_ts(11)
179 )
162180 self.assertEqual(r, 0)
163181
164182 # now with one event
165183 yield add_event(2, 10)
166 r = yield self.store.find_first_stream_ordering_after_ts(9)
184 r = yield defer.ensureDeferred(
185 self.store.find_first_stream_ordering_after_ts(9)
186 )
167187 self.assertEqual(r, 2)
168 r = yield self.store.find_first_stream_ordering_after_ts(10)
188 r = yield defer.ensureDeferred(
189 self.store.find_first_stream_ordering_after_ts(10)
190 )
169191 self.assertEqual(r, 2)
170 r = yield self.store.find_first_stream_ordering_after_ts(11)
192 r = yield defer.ensureDeferred(
193 self.store.find_first_stream_ordering_after_ts(11)
194 )
171195 self.assertEqual(r, 3)
172196
173197 # add a bunch of dummy events to the events table
180204 ):
181205 yield add_event(stream_ordering, ts)
182206
183 r = yield self.store.find_first_stream_ordering_after_ts(110)
207 r = yield defer.ensureDeferred(
208 self.store.find_first_stream_ordering_after_ts(110)
209 )
184210 self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
185211
186212 # 4 and 5 are both after 120: we want 4 rather than 5
187 r = yield self.store.find_first_stream_ordering_after_ts(120)
213 r = yield defer.ensureDeferred(
214 self.store.find_first_stream_ordering_after_ts(120)
215 )
188216 self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
189217
190 r = yield self.store.find_first_stream_ordering_after_ts(129)
218 r = yield defer.ensureDeferred(
219 self.store.find_first_stream_ordering_after_ts(129)
220 )
191221 self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
192222
193223 # check we can get the last event
194 r = yield self.store.find_first_stream_ordering_after_ts(140)
224 r = yield defer.ensureDeferred(
225 self.store.find_first_stream_ordering_after_ts(140)
226 )
195227        self.assertEqual(r, 20, "First event after 140ms should be 20, was %i" % r)
196228
197229 # off the end
198 r = yield self.store.find_first_stream_ordering_after_ts(160)
230 r = yield defer.ensureDeferred(
231 self.store.find_first_stream_ordering_after_ts(160)
232 )
199233 self.assertEqual(r, 21)
200234
201235 # check we can find an event at ordering zero
202236 yield add_event(0, 5)
203 r = yield self.store.find_first_stream_ordering_after_ts(1)
237 r = yield defer.ensureDeferred(
238 self.store.find_first_stream_ordering_after_ts(1)
239 )
204240 self.assertEqual(r, 0)
5757 return self.get_success(self.db_pool.runWithConnection(_create))
5858
5959 def _insert_rows(self, instance_name: str, number: int):
60        """Insert N rows as the given instance, with stream IDs pulled
61 from the postgres sequence.
62 """
63
6064 def _insert(txn):
6165 for _ in range(number):
6266 txn.execute(
6468 (instance_name,),
6569 )
6670
67 self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
71 self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
72
73 def _insert_row_with_id(self, instance_name: str, stream_id: int):
74 """Insert one row as the given instance with given stream_id, updating
75 the postgres sequence position to match.
76 """
77
78 def _insert(txn):
79 txn.execute(
80 "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
81 )
82 txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
83
84 self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
6885
6986 def test_empty(self):
7087 """Test an ID generator against an empty database gives sensible
87104 id_gen = self._create_id_generator()
88105
89106 self.assertEqual(id_gen.get_positions(), {"master": 7})
90 self.assertEqual(id_gen.get_current_token("master"), 7)
107 self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
91108
92109 # Try allocating a new ID gen and check that we only see position
93110 # advanced after we leave the context manager.
97114 self.assertEqual(stream_id, 8)
98115
99116 self.assertEqual(id_gen.get_positions(), {"master": 7})
100 self.assertEqual(id_gen.get_current_token("master"), 7)
117 self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
101118
102119 self.get_success(_get_next_async())
103120
104121 self.assertEqual(id_gen.get_positions(), {"master": 8})
105 self.assertEqual(id_gen.get_current_token("master"), 8)
122 self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
106123
107124 def test_multi_instance(self):
108125 """Test that reads and writes from multiple processes are handled
115132 second_id_gen = self._create_id_generator("second")
116133
117134 self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
118 self.assertEqual(first_id_gen.get_current_token("first"), 3)
119 self.assertEqual(first_id_gen.get_current_token("second"), 7)
135 self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
136 self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
120137
121138 # Try allocating a new ID gen and check that we only see position
122139 # advanced after we leave the context manager.
165182 id_gen = self._create_id_generator()
166183
167184 self.assertEqual(id_gen.get_positions(), {"master": 7})
168 self.assertEqual(id_gen.get_current_token("master"), 7)
185 self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
169186
170187 # Try allocating a new ID gen and check that we only see position
171188 # advanced after we leave the context manager.
175192 self.assertEqual(stream_id, 8)
176193
177194 self.assertEqual(id_gen.get_positions(), {"master": 7})
178 self.assertEqual(id_gen.get_current_token("master"), 7)
195 self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
179196
180197 self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
181198
182199 self.assertEqual(id_gen.get_positions(), {"master": 8})
183 self.assertEqual(id_gen.get_current_token("master"), 8)
200 self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
201
202 def test_get_persisted_upto_position(self):
203 """Test that `get_persisted_upto_position` correctly tracks updates to
204 positions.
205 """
206
207 # The following tests are a bit cheeky in that we notify about new
208 # positions via `advance` without *actually* advancing the postgres
209 # sequence.
210
211 self._insert_row_with_id("first", 3)
212 self._insert_row_with_id("second", 5)
213
214 id_gen = self._create_id_generator("first")
215
216 self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
217
218        # Min is 3 and there is a gap between 3 and 5, so we expect it to be 3.
219 self.assertEqual(id_gen.get_persisted_upto_position(), 3)
220
221        # We advance "first" straight to 6. Min is now 5 but there is no gap, so
222        # we expect it to be 6.
223 id_gen.advance("first", 6)
224 self.assertEqual(id_gen.get_persisted_upto_position(), 6)
225
226 # No gap, so we expect 7.
227 id_gen.advance("second", 7)
228 self.assertEqual(id_gen.get_persisted_upto_position(), 7)
229
230 # We haven't seen 8 yet, so we expect 7 still.
231 id_gen.advance("second", 9)
232 self.assertEqual(id_gen.get_persisted_upto_position(), 7)
233
234        # Now that we've seen 7, 8 and 9, we can go straight to 9.
235 id_gen.advance("first", 8)
236 self.assertEqual(id_gen.get_persisted_upto_position(), 9)
237
238        # Jump forward with gaps. The minimum is 11; even though we haven't seen
239        # 10, we know that everything before 11 must be persisted.
240 id_gen.advance("first", 11)
241 id_gen.advance("second", 15)
242 self.assertEqual(id_gen.get_persisted_upto_position(), 11)
243
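# The invariant these asserts encode, as read from the comments above:
# get_persisted_upto_position() returns the highest stream ID below which
# no gaps remain. A rough model (an assumption, not the implementation):
def persisted_upto(min_writer_position: int, persisted_ids: set) -> int:
    pos = min_writer_position
    # Walk forward while the next ID is known to be persisted.
    while pos + 1 in persisted_ids:
        pos += 1
    return pos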
244 def test_get_persisted_upto_position_get_next(self):
245 """Test that `get_persisted_upto_position` correctly tracks updates to
246 positions when `get_next` is called.
247 """
248
249 self._insert_row_with_id("first", 3)
250 self._insert_row_with_id("second", 5)
251
252 id_gen = self._create_id_generator("first")
253
254 self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
255
256 self.assertEqual(id_gen.get_persisted_upto_position(), 3)
257 with self.get_success(id_gen.get_next()) as stream_id:
258 self.assertEqual(stream_id, 6)
259 self.assertEqual(id_gen.get_persisted_upto_position(), 3)
260
261 self.assertEqual(id_gen.get_persisted_upto_position(), 6)
262
263 # We assume that so long as `get_next` does correctly advance the
264 # `persisted_upto_position` in this case, then it will be correct in the
265 # other cases that are tested above (since they'll hit the same code).
266
267
268 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
269 """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
270 """
271
272 if not USE_POSTGRES_FOR_TESTS:
273 skip = "Requires Postgres"
274
275 def prepare(self, reactor, clock, hs):
276 self.store = hs.get_datastore()
277 self.db_pool = self.store.db_pool # type: DatabasePool
278
279 self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
280
281 def _setup_db(self, txn):
282 txn.execute("CREATE SEQUENCE foobar_seq")
283 txn.execute(
284 """
285 CREATE TABLE foobar (
286 stream_id BIGINT NOT NULL,
287 instance_name TEXT NOT NULL,
288 data TEXT
289 );
290 """
291 )
292
293 def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
294 def _create(conn):
295 return MultiWriterIdGenerator(
296 conn,
297 self.db_pool,
298 instance_name=instance_name,
299 table="foobar",
300 instance_column="instance_name",
301 id_column="stream_id",
302 sequence_name="foobar_seq",
303 positive=False,
304 )
305
306 return self.get_success(self.db_pool.runWithConnection(_create))
307
308 def _insert_row(self, instance_name: str, stream_id: int):
309 """Insert one row as the given instance with given stream_id.
310 """
311
312 def _insert(txn):
313 txn.execute(
314 "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
315 )
316
317 self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
318
319 def test_single_instance(self):
320 """Test that reads and writes from a single process are handled
321 correctly.
322 """
323 id_gen = self._create_id_generator()
324
325 with self.get_success(id_gen.get_next()) as stream_id:
326 self._insert_row("master", stream_id)
327
328 self.assertEqual(id_gen.get_positions(), {"master": -1})
329 self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
330 self.assertEqual(id_gen.get_persisted_upto_position(), -1)
331
332 with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
333 for stream_id in stream_ids:
334 self._insert_row("master", stream_id)
335
336 self.assertEqual(id_gen.get_positions(), {"master": -4})
337 self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
338 self.assertEqual(id_gen.get_persisted_upto_position(), -4)
339
340 # Test loading from DB by creating a second ID gen
341 second_id_gen = self._create_id_generator()
342
343 self.assertEqual(second_id_gen.get_positions(), {"master": -4})
344 self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
345 self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
346
347 def test_multiple_instance(self):
348 """Tests that having multiple instances that get advanced over
349 federation works correctly.
350 """
351 id_gen_1 = self._create_id_generator("first")
352 id_gen_2 = self._create_id_generator("second")
353
354 with self.get_success(id_gen_1.get_next()) as stream_id:
355 self._insert_row("first", stream_id)
356 id_gen_2.advance("first", stream_id)
357
358 self.assertEqual(id_gen_1.get_positions(), {"first": -1})
359 self.assertEqual(id_gen_2.get_positions(), {"first": -1})
360 self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
361 self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
362
363 with self.get_success(id_gen_2.get_next()) as stream_id:
364 self._insert_row("second", stream_id)
365 id_gen_1.advance("second", stream_id)
366
367 self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
368 self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
369 self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
370 self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
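For the backwards case the same persisted-up-to rule applies, just over descending negative IDs: the guaranteed point is still the minimum (here, the most negative) position across writers, which is what the `-2` assertions above check. Illustratively:

    # Backwards generators hand out -1, -2, -3, ...; the persisted-up-to
    # position is still the minimum across writers.
    positions = {"first": -1, "second": -2}
    assert min(positions.values()) == -2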
3333
3434 @defer.inlineCallbacks
3535 def test_get_users_paginate(self):
36 yield self.store.register_user(self.user.to_string(), "pass")
37 yield self.store.create_profile(self.user.localpart)
38 yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
36 yield defer.ensureDeferred(
37 self.store.register_user(self.user.to_string(), "pass")
38 )
39 yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
40 yield defer.ensureDeferred(
41 self.store.set_profile_displayname(self.user.localpart, self.displayname)
42 )
3943
40 users, total = yield self.store.get_users_paginate(
41 0, 10, name="bc", guests=False
44 users, total = yield defer.ensureDeferred(
45 self.store.get_users_paginate(0, 10, name="bc", guests=False)
4246 )
4347
4448 self.assertEquals(1, total)
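The change above is the pattern repeated throughout the rest of this diff: the storage methods have become native `async def` coroutines, and a `@defer.inlineCallbacks` generator can only `yield` Deferreds, so each call is wrapped in `defer.ensureDeferred`. A self-contained sketch of why the wrapper is needed:

    from twisted.internet import defer

    async def storage_call():  # stands in for the now-async store methods
        return 42

    @defer.inlineCallbacks
    def legacy_test():
        # Yielding the bare coroutine would just hand the coroutine object
        # back unrun; ensureDeferred turns it into a Deferred that
        # inlineCallbacks knows how to drive.
        result = yield defer.ensureDeferred(storage_call())
        assert result == 42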
3232
3333 @defer.inlineCallbacks
3434 def test_displayname(self):
35 yield self.store.create_profile(self.u_frank.localpart)
35 yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
3636
37 yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
37 yield defer.ensureDeferred(
38 self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
39 )
3840
3941 self.assertEquals(
40 "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
42 "Frank",
43 (
44 yield defer.ensureDeferred(
45 self.store.get_profile_displayname(self.u_frank.localpart)
46 )
47 ),
4148 )
4249
4350 @defer.inlineCallbacks
4451 def test_avatar_url(self):
45 yield self.store.create_profile(self.u_frank.localpart)
52 yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
4653
47 yield self.store.set_profile_avatar_url(
48 self.u_frank.localpart, "http://my.site/here"
54 yield defer.ensureDeferred(
55 self.store.set_profile_avatar_url(
56 self.u_frank.localpart, "http://my.site/here"
57 )
4958 )
5059
5160 self.assertEquals(
5261 "http://my.site/here",
53 (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
62 (
63 yield defer.ensureDeferred(
64 self.store.get_profile_avatar_url(self.u_frank.localpart)
65 )
66 ),
5467 )
1414
1515 from twisted.internet import defer
1616
17 from synapse.api.errors import NotFoundError
1718 from synapse.rest.client.v1 import room
1819
1920 from tests.unittest import HomeserverTestCase
4546 storage = self.hs.get_storage()
4647
4748 # Get the topological token
48 event = store.get_topological_token_for_event(last["event_id"])
49 self.pump()
50 event = self.successResultOf(event)
49 event = self.get_success(
50 store.get_topological_token_for_event(last["event_id"])
51 )
5152
5253 # Purge everything before this topological token
53 purge = defer.ensureDeferred(
54 storage.purge_events.purge_history(self.room_id, event, True)
55 )
56 self.pump()
57 self.assertEqual(self.successResultOf(purge), None)
58
59 # Try and get the events
60 get_first = store.get_event(first["event_id"])
61 get_second = store.get_event(second["event_id"])
62 get_third = store.get_event(third["event_id"])
63 get_last = store.get_event(last["event_id"])
64 self.pump()
54 self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
6555
6656 # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
6757 # and last is not.
68 self.failureResultOf(get_first)
69 self.failureResultOf(get_second)
70 self.failureResultOf(get_third)
71 self.successResultOf(get_last)
58 self.get_failure(store.get_event(first["event_id"]), NotFoundError)
59 self.get_failure(store.get_event(second["event_id"]), NotFoundError)
60 self.get_failure(store.get_event(third["event_id"]), NotFoundError)
61 self.get_success(store.get_event(last["event_id"]))
7262
7363 def test_purge_wont_delete_extrems(self):
7464 """
8373 storage = self.hs.get_datastore()
8474
8575 # Set the topological token higher than it should be
86 event = storage.get_topological_token_for_event(last["event_id"])
87 self.pump()
88 event = self.successResultOf(event)
76 event = self.get_success(
77 storage.get_topological_token_for_event(last["event_id"])
78 )
8979 event = "t{}-{}".format(
9080 *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
9181 )
9787 self.assertIn("greater than forward", f.value.args[0])
9888
9989 # Try and get the events
100 get_first = storage.get_event(first["event_id"])
101 get_second = storage.get_event(second["event_id"])
102 get_third = storage.get_event(third["event_id"])
103 get_last = storage.get_event(last["event_id"])
104 self.pump()
105
106 # Nothing is deleted.
107 self.successResultOf(get_first)
108 self.successResultOf(get_second)
109 self.successResultOf(get_third)
110 self.successResultOf(get_last)
90 self.get_success(storage.get_event(first["event_id"]))
91 self.get_success(storage.get_event(second["event_id"]))
92 self.get_success(storage.get_event(third["event_id"]))
93 self.get_success(storage.get_event(last["event_id"]))
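The purge tests also swap the manual pump/successResultOf/failureResultOf dance for the `HomeserverTestCase` helpers. The rough equivalence, sketched below, is illustrative rather than the real helper bodies:

    def get_success(case, deferred):
        # Pump the reactor so pending callbacks run, then unwrap the result.
        case.pump()
        return case.successResultOf(deferred)

    def get_failure(case, deferred, exc_type):
        # As above, but assert the Deferred failed with the given type.
        case.pump()
        return case.failureResultOf(deferred, exc_type)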
1616 from twisted.internet import defer
1717
1818 from synapse.api.constants import UserTypes
19 from synapse.api.errors import ThreepidValidationError
1920
2021 from tests import unittest
2122 from tests.utils import setup_test_homeserver
3536
3637 @defer.inlineCallbacks
3738 def test_register(self):
38 yield self.store.register_user(self.user_id, self.pwhash)
39 yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
3940
4041 self.assertEquals(
4142 {
5152 "user_type": None,
5253 "deactivated": 0,
5354 },
54 (yield self.store.get_user_by_id(self.user_id)),
55 (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
5556 )
5657
5758 @defer.inlineCallbacks
5859 def test_add_tokens(self):
59 yield self.store.register_user(self.user_id, self.pwhash)
60 yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
6061 yield defer.ensureDeferred(
6162 self.store.add_access_token_to_user(
6263 self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
6364 )
6465 )
6566
66 result = yield self.store.get_user_by_access_token(self.tokens[1])
67 result = yield defer.ensureDeferred(
68 self.store.get_user_by_access_token(self.tokens[1])
69 )
6770
6871 self.assertDictContainsSubset(
6972 {"name": self.user_id, "device_id": self.device_id}, result
7477 @defer.inlineCallbacks
7578 def test_user_delete_access_tokens(self):
7679 # add some tokens
77 yield self.store.register_user(self.user_id, self.pwhash)
80 yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
7881 yield defer.ensureDeferred(
7982 self.store.add_access_token_to_user(
8083 self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
8790 )
8891
8992 # now delete some
90 yield self.store.user_delete_access_tokens(
91 self.user_id, device_id=self.device_id
93 yield defer.ensureDeferred(
94 self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
9295 )
9396
9497 # check they were deleted
95 user = yield self.store.get_user_by_access_token(self.tokens[1])
98 user = yield defer.ensureDeferred(
99 self.store.get_user_by_access_token(self.tokens[1])
100 )
96101 self.assertIsNone(user, "access token was not deleted by device_id")
97102
98103 # check the one not associated with the device was not deleted
99 user = yield self.store.get_user_by_access_token(self.tokens[0])
104 user = yield defer.ensureDeferred(
105 self.store.get_user_by_access_token(self.tokens[0])
106 )
100107 self.assertEqual(self.user_id, user["name"])
101108
102109 # now delete the rest
103 yield self.store.user_delete_access_tokens(self.user_id)
110 yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
104111
105 user = yield self.store.get_user_by_access_token(self.tokens[0])
112 user = yield defer.ensureDeferred(
113 self.store.get_user_by_access_token(self.tokens[0])
114 )
106115 self.assertIsNone(user, "access token was not deleted without device_id")
107116
108117 @defer.inlineCallbacks
110119 TEST_USER = "@test:test"
111120 SUPPORT_USER = "@support:test"
112121
113 res = yield self.store.is_support_user(None)
122 res = yield defer.ensureDeferred(self.store.is_support_user(None))
114123 self.assertFalse(res)
115 yield self.store.register_user(user_id=TEST_USER, password_hash=None)
116 res = yield self.store.is_support_user(TEST_USER)
124 yield defer.ensureDeferred(
125 self.store.register_user(user_id=TEST_USER, password_hash=None)
126 )
127 res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
117128 self.assertFalse(res)
118129
119 yield self.store.register_user(
120 user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
130 yield defer.ensureDeferred(
131 self.store.register_user(
132 user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
133 )
121134 )
122 res = yield self.store.is_support_user(SUPPORT_USER)
135 res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
123136 self.assertTrue(res)
137
138 @defer.inlineCallbacks
139 def test_3pid_inhibit_invalid_validation_session_error(self):
140 """Tests that enabling the configuration option to inhibit 3PID errors on
141 /requestToken also inhibits validation errors caused by an unknown session ID.
142 """
143
144 # Check that, with the config setting set to false (the default value), a
145 # validation error is caused by the unknown session ID.
146 try:
147 yield defer.ensureDeferred(
148 self.store.validate_threepid_session(
149 "fake_sid", "fake_client_secret", "fake_token", 0,
150 )
151 )
152 except ThreepidValidationError as e:
153 self.assertEquals(e.msg, "Unknown session_id", e)
154
155 # Set the config setting to true.
156 self.store._ignore_unknown_session_error = True
157
158 # Check that now the validation error is caused by the token not matching.
159 try:
160 yield defer.ensureDeferred(
161 self.store.validate_threepid_session(
162 "fake_sid", "fake_client_secret", "fake_token", 0,
163 )
164 )
165 except ThreepidValidationError as e:
166 self.assertEquals(e.msg, "Validation token not found or has expired", e)
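One caveat with the try/except shape used in both halves of this test: if `validate_threepid_session` ever stopped raising, the test would pass silently. A tighter (illustrative) variant makes the expectation explicit:

    try:
        yield defer.ensureDeferred(
            self.store.validate_threepid_session(
                "fake_sid", "fake_client_secret", "fake_token", 0,
            )
        )
    except ThreepidValidationError as e:
        self.assertEquals(e.msg, "Unknown session_id", e)
    else:
        self.fail("expected a ThreepidValidationError")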
5353 "creator": self.u_creator.to_string(),
5454 "is_public": True,
5555 },
56 (yield self.store.get_room(self.room.to_string())),
56 (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
5757 )
5858
5959 @defer.inlineCallbacks
6060 def test_get_room_unknown_room(self):
61 self.assertIsNone((yield self.store.get_room("!unknown:test")),)
61 self.assertIsNone(
62 (yield defer.ensureDeferred(self.store.get_room("!unknown:test")))
63 )
6264
6365 @defer.inlineCallbacks
6466 def test_get_room_with_stats(self):
6870 "creator": self.u_creator.to_string(),
6971 "public": True,
7072 },
71 (yield self.store.get_room_with_stats(self.room.to_string())),
73 (
74 yield defer.ensureDeferred(
75 self.store.get_room_with_stats(self.room.to_string())
76 )
77 ),
7278 )
7379
7480 @defer.inlineCallbacks
7581 def test_get_room_with_stats_unknown_room(self):
76 self.assertIsNone((yield self.store.get_room_with_stats("!unknown:test")),)
82 self.assertIsNone(
83 (
84 yield defer.ensureDeferred(
85 self.store.get_room_with_stats("!unknown:test")
86 )
87 ),
88 )
7789
7890
7991 class RoomEventsStoreTestCase(unittest.TestCase):
8686 self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
8787 self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
8888
89 self.pump(20)
89 self.pump()
9090
9191 self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
9292
100100 # Initialises to 1 -- itself
101101 self.assertEqual(self.store._known_servers_count, 1)
102102
103 self.pump(20)
103 self.pump()
104104
105105 # No rooms have been joined, so technically the SQL returns 0, but it
106106 # will still say it knows about itself.
110110 self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
111111 self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
112112
113 self.pump(20)
113 self.pump(1)
114114
115115 # It now knows about Charlie's server.
116116 self.assertEqual(self.store._known_servers_count, 2)
186186
187187 # Now let's create a room, which will insert a membership
188188 user = UserID("alice", "test")
189 requester = Requester(user, None, False, None, None)
189 requester = Requester(user, None, False, False, None, None)
190190 self.get_success(self.room_creator.create_room(requester, {}))
191191
192192 # Register the background update to run again.
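The extra `False` threaded into these `Requester(...)` calls corresponds to a new boolean field added to the tuple in this release cycle (the shadow-ban flag, as far as we can tell). Where possible, `create_requester` keeps tests insulated from such positional changes:

    from synapse.types import create_requester

    # The factory fills new fields with their defaults, so call sites do not
    # need updating each time Requester grows a field.
    requester = create_requester(user)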
3030
3131 # alice and bob are both in !room_id. bobby is not but shares
3232 # a homeserver with alice.
33 yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
34 yield self.store.update_profile_in_user_dir(BOB, "bob", None)
35 yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
36 yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
33 yield defer.ensureDeferred(
34 self.store.update_profile_in_user_dir(ALICE, "alice", None)
35 )
36 yield defer.ensureDeferred(
37 self.store.update_profile_in_user_dir(BOB, "bob", None)
38 )
39 yield defer.ensureDeferred(
40 self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
41 )
42 yield defer.ensureDeferred(
43 self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
44 )
3745
3846 @defer.inlineCallbacks
3947 def test_search_user_dir(self):
1414
1515 from mock import Mock
1616
17 from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
18
17 from twisted.internet.defer import succeed
18
19 from synapse.api.errors import FederationError
1920 from synapse.events import make_event_from_dict
2021 from synapse.logging.context import LoggingContext
2122 from synapse.types import Requester, UserID
4142 )
4243
4344 user_id = UserID("us", "test")
44 our_user = Requester(user_id, None, False, None, None)
45 our_user = Requester(user_id, None, False, False, None, None)
4546 room_creator = self.homeserver.get_room_creation_handler()
46 room_deferred = ensureDeferred(
47 self.room_id = self.get_success(
4748 room_creator.create_room(
4849 our_user, room_creator._presets_dict["public_chat"], ratelimit=False
4950 )
50 )
51 self.reactor.advance(0.1)
52 self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
51 )[0]["room_id"]
5352
5453 self.store = self.homeserver.get_datastore()
5554
5655 # Figure out what the most recent event is
57 most_recent = self.successResultOf(
58 maybeDeferred(
59 self.homeserver.get_datastore().get_latest_event_ids_in_room,
60 self.room_id,
61 )
56 most_recent = self.get_success(
57 self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
6258 )[0]
6359
6460 join_event = make_event_from_dict(
8884 )
8985
9086 # Send the join, it should return None (which is not an error)
91 d = ensureDeferred(
92 self.handler.on_receive_pdu(
93 "test.serv", join_event, sent_to_us_directly=True
94 )
95 )
96 self.reactor.advance(1)
97 self.assertEqual(self.successResultOf(d), None)
87 self.assertEqual(
88 self.get_success(
89 self.handler.on_receive_pdu(
90 "test.serv", join_event, sent_to_us_directly=True
91 )
92 ),
93 None,
94 )
9895
9996 # Make sure we actually joined the room
10097 self.assertEqual(
101 self.successResultOf(
102 maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
103 )[0],
98 self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
10499 "$join:test.serv",
105100 )
106101
118113 self.http_client.post_json = post_json
119114
120115 # Figure out what the most recent event is
121 most_recent = self.successResultOf(
122 maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
116 most_recent = self.get_success(
117 self.store.get_latest_event_ids_in_room(self.room_id)
123118 )[0]
124119
125120 # Now lie about an event
139134 )
140135
141136 with LoggingContext(request="lying_event"):
142 d = ensureDeferred(
137 failure = self.get_failure(
143138 self.handler.on_receive_pdu(
144139 "test.serv", lying_event, sent_to_us_directly=True
145 )
146 )
147
148 # Step the reactor, so the database fetches come back
149 self.reactor.advance(1)
140 ),
141 FederationError,
142 )
150143
151144 # on_receive_pdu should throw an error
152 failure = self.failureResultOf(d)
153145 self.assertEqual(
154146 failure.value.args[0],
155147 (
159151 )
160152
161153 # Make sure the invalid event isn't there
162 extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
163 self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
154 extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
155 self.assertEqual(extrem[0], "$join:test.serv")
164156
165157 def test_retry_device_list_resync(self):
166158 """Tests that device lists are marked as stale if they couldn't be synced, and
177177
178178 self.assertEqual(channel.result["code"], b"200")
179179 self.assertNotIn("body", channel.result)
180 self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
181180
182181
183182 class OptionsResourceTests(unittest.TestCase):
7070 return event
7171
7272
73 class StateGroupStore(object):
73 class StateGroupStore:
7474 def __init__(self):
7575 self._event_to_state_group = {}
7676 self._group_to_state = {}
7979
8080 self._next_group = 1
8181
82 def get_state_groups_ids(self, room_id, event_ids):
82 async def get_state_groups_ids(self, room_id, event_ids):
8383 groups = {}
8484 for event_id in event_ids:
8585 group = self._event_to_state_group.get(event_id)
8686 if group:
8787 groups[group] = self._group_to_state[group]
8888
89 return defer.succeed(groups)
90
91 def store_state_group(
89 return groups
90
91 async def store_state_group(
9292 self, event_id, room_id, prev_group, delta_ids, current_state_ids
9393 ):
9494 state_group = self._next_group
9696
9797 self._group_to_state[state_group] = dict(current_state_ids)
9898
99 return defer.succeed(state_group)
100
101 def get_events(self, event_ids, **kwargs):
102 return defer.succeed(
103 {
104 e_id: self._event_id_to_event[e_id]
105 for e_id in event_ids
106 if e_id in self._event_id_to_event
107 }
108 )
109
110 def get_state_group_delta(self, name):
111 return defer.succeed((None, None))
99 return state_group
100
101 async def get_events(self, event_ids, **kwargs):
102 return {
103 e_id: self._event_id_to_event[e_id]
104 for e_id in event_ids
105 if e_id in self._event_id_to_event
106 }
107
108 async def get_state_group_delta(self, name):
109 return (None, None)
112110
113111 def register_events(self, events):
114112 for e in events:
120118 def register_event_id_state_group(self, event_id, state_group):
121119 self._event_to_state_group[event_id] = state_group
122120
123 def get_room_version_id(self, room_id):
124 return defer.succeed(RoomVersions.V1.identifier)
121 async def get_room_version_id(self, room_id):
122 return RoomVersions.V1.identifier
125123
126124
127125 class DictObj(dict):
130128 self.__dict__ = self
131129
132130
133 class Graph(object):
131 class Graph:
134132 def __init__(self, nodes, edges):
135133 events = {}
136134 clobbered = set(events.keys())
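The `StateGroupStore` conversion above follows the same shape as the real DataStore: methods that returned pre-fired Deferreds via `defer.succeed` become coroutines returning plain values. Driving such a method from an `@defer.inlineCallbacks` test then needs the `ensureDeferred` wrapping seen below; a minimal sketch with illustrative names:

    from twisted.internet import defer

    class FakeStore:
        async def get_state_group_delta(self, name):
            return (None, None)  # was: return defer.succeed((None, None))

    @defer.inlineCallbacks
    def check(store):
        # coroutine -> Deferred, so inlineCallbacks can drive it
        delta = yield defer.ensureDeferred(store.get_state_group_delta("g1"))
        assert delta == (None, None)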
475473 create_event(type="test2", state_key=""),
476474 ]
477475
478 group_name = yield self.store.store_state_group(
479 prev_event_id,
480 event.room_id,
481 None,
482 None,
483 {(e.type, e.state_key): e.event_id for e in old_state},
476 group_name = yield defer.ensureDeferred(
477 self.store.store_state_group(
478 prev_event_id,
479 event.room_id,
480 None,
481 None,
482 {(e.type, e.state_key): e.event_id for e in old_state},
483 )
484484 )
485485 self.store.register_event_id_state_group(prev_event_id, group_name)
486486
507507 create_event(type="test2", state_key=""),
508508 ]
509509
510 group_name = yield self.store.store_state_group(
511 prev_event_id,
512 event.room_id,
513 None,
514 None,
515 {(e.type, e.state_key): e.event_id for e in old_state},
510 group_name = yield defer.ensureDeferred(
511 self.store.store_state_group(
512 prev_event_id,
513 event.room_id,
514 None,
515 None,
516 {(e.type, e.state_key): e.event_id for e in old_state},
517 )
516518 )
517519 self.store.register_event_id_state_group(prev_event_id, group_name)
518520
690692 def _get_context(
691693 self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
692694 ):
693 sg1 = yield self.store.store_state_group(
694 prev_event_id_1,
695 event.room_id,
696 None,
697 None,
698 {(e.type, e.state_key): e.event_id for e in old_state_1},
695 sg1 = yield defer.ensureDeferred(
696 self.store.store_state_group(
697 prev_event_id_1,
698 event.room_id,
699 None,
700 None,
701 {(e.type, e.state_key): e.event_id for e in old_state_1},
702 )
699703 )
700704 self.store.register_event_id_state_group(prev_event_id_1, sg1)
701705
702 sg2 = yield self.store.store_state_group(
703 prev_event_id_2,
704 event.room_id,
705 None,
706 None,
707 {(e.type, e.state_key): e.event_id for e in old_state_2},
706 sg2 = yield defer.ensureDeferred(
707 self.store.store_state_group(
708 prev_event_id_2,
709 event.room_id,
710 None,
711 None,
712 {(e.type, e.state_key): e.event_id for e in old_state_2},
713 )
708714 )
709715 self.store.register_event_id_state_group(prev_event_id_2, sg2)
710716
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15 from typing import Optional, Tuple
15 from typing import List, Optional, Tuple
1616
1717 import synapse.server
1818 from synapse.api.constants import EventTypes
1919 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
2020 from synapse.events import EventBase
2121 from synapse.events.snapshot import EventContext
22 from synapse.types import Collection
2322
2423 """
2524 Utility functions for poking events into the storage of the server under test.
5756 async def inject_event(
5857 hs: synapse.server.HomeServer,
5958 room_version: Optional[str] = None,
60 prev_event_ids: Optional[Collection[str]] = None,
59 prev_event_ids: Optional[List[str]] = None,
6160 **kwargs
6261 ) -> EventBase:
6362 """Inject a generic event into a room
7978 async def create_event(
8079 hs: synapse.server.HomeServer,
8180 room_version: Optional[str] = None,
82 prev_event_ids: Optional[Collection[str]] = None,
81 prev_event_ids: Optional[List[str]] = None,
8382 **kwargs
8483 ) -> Tuple[EventBase, EventContext]:
8584 if room_version is None:
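Our reading of the `Collection[str]` to `List[str]` narrowing above: `prev_event_ids` flows into code that relies on sequence behaviour, which a bare `Collection` does not promise. Illustratively:

    from typing import List

    def first_prev(prev_event_ids: List[str]) -> str:
        # Indexing needs a sequence; Collection only guarantees
        # __len__, __iter__ and __contains__.
        return prev_event_ids[0]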
3636 self.hs = yield setup_test_homeserver(self.addCleanup)
3737 self.event_creation_handler = self.hs.get_event_creation_handler()
3838 self.event_builder_factory = self.hs.get_event_builder_factory()
39 self.store = self.hs.get_datastore()
4039 self.storage = self.hs.get_storage()
4140
4241 yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
9897 events_to_filter.append(evt)
9998
10099 # the erasey user gets erased
101 yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
100 yield defer.ensureDeferred(
101 self.hs.get_datastore().mark_user_erased("@erased:local_hs")
102 )
102103
103104 # ... and the filtering happens.
104105 filtered = yield defer.ensureDeferred(
292293 test_large_room.skip = "Disabled by default because it's slow"
293294
294295
295 class _TestStore(object):
296 class _TestStore:
296297 """Implements a few methods of the DataStore, so that we can test
297298 filter_events_for_server
298299
249249
250250 async def get_user_by_req(request, allow_guest=False, rights="access"):
251251 return create_requester(
252 UserID.from_string(self.helper.auth_user_id), 1, False, None
252 UserID.from_string(self.helper.auth_user_id),
253 1,
254 False,
255 False,
256 None,
253257 )
254258
255259 self.hs.get_auth().get_user_by_req = get_user_by_req
539543 """
540544 event_creator = self.hs.get_event_creation_handler()
541545 secrets = self.hs.get_secrets()
542 requester = Requester(user, None, False, None, None)
546 requester = Requester(user, None, False, False, None, None)
543547
544548 event, context = self.get_success(
545549 event_creator.create_event(
609613 """
610614
611615 def prepare(self, reactor, clock, homeserver):
612 class Authenticator(object):
616 class Authenticator:
613617 def authenticate_request(self, request, content):
614618 return succeed("other.example.com")
615619
8787 class DescriptorTestCase(unittest.TestCase):
8888 @defer.inlineCallbacks
8989 def test_cache(self):
90 class Cls(object):
90 class Cls:
9191 def __init__(self):
9292 self.mock = mock.Mock()
9393
121121 def test_cache_num_args(self):
122122 """Only the first num_args arguments should matter to the cache"""
123123
124 class Cls(object):
124 class Cls:
125125 def __init__(self):
126126 self.mock = mock.Mock()
127127
155155 """If the wrapped function throws synchronously, things should continue to work
156156 """
157157
158 class Cls(object):
158 class Cls:
159159 @cached()
160160 def fn(self, arg1):
161161 raise SynapseError(100, "mai spoon iz too big!!1")
179179
180180 complete_lookup = defer.Deferred()
181181
182 class Cls(object):
182 class Cls:
183183 @descriptors.cached()
184184 def fn(self, arg1):
185185 @defer.inlineCallbacks
222222 """Check that the cache sets and restores logcontexts correctly when
223223 the lookup function throws an exception"""
224224
225 class Cls(object):
225 class Cls:
226226 @descriptors.cached()
227227 def fn(self, arg1):
228228 @defer.inlineCallbacks
262262
263263 @defer.inlineCallbacks
264264 def test_cache_default_args(self):
265 class Cls(object):
265 class Cls:
266266 def __init__(self):
267267 self.mock = mock.Mock()
268268
299299 obj.mock.assert_not_called()
300300
301301 def test_cache_iterable(self):
302 class Cls(object):
302 class Cls:
303303 def __init__(self):
304304 self.mock = mock.Mock()
305305
335335 """If the wrapped function throws synchronously, things should continue to work
336336 """
337337
338 class Cls(object):
338 class Cls:
339339 @descriptors.cached(iterable=True)
340340 def fn(self, arg1):
341341 raise SynapseError(100, "mai spoon iz too big!!1")
357357 class CachedListDescriptorTestCase(unittest.TestCase):
358358 @defer.inlineCallbacks
359359 def test_cache(self):
360 class Cls(object):
360 class Cls:
361361 def __init__(self):
362362 self.mock = mock.Mock()
363363
365365 def fn(self, arg1, arg2):
366366 pass
367367
368 @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
369 def list_fn(self, args1, arg2):
368 @descriptors.cachedList("fn", "args1")
369 async def list_fn(self, args1, arg2):
370370 assert current_context().request == "c1"
371371 # we want this to behave like an asynchronous function
372 yield run_on_reactor()
372 await run_on_reactor()
373373 assert current_context().request == "c1"
374374 return self.mock(args1, arg2)
375375
407407 def test_invalidate(self):
408408 """Make sure that invalidation callbacks are called."""
409409
410 class Cls(object):
410 class Cls:
411411 def __init__(self):
412412 self.mock = mock.Mock()
413413
415415 def fn(self, arg1, arg2):
416416 pass
417417
418 @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
419 def list_fn(self, args1, arg2):
418 @descriptors.cachedList("fn", "args1")
419 async def list_fn(self, args1, arg2):
420420 # we want this to behave like an asynchronous function
421 yield run_on_reactor()
421 await run_on_reactor()
422422 return self.mock(args1, arg2)
423423
424424 obj = Cls()
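The decorator change above is the same async migration applied to `cachedList`: the `inlineCallbacks=True` flag goes away and the batch function becomes a native coroutine. Sketched with illustrative names:

    from synapse.util.caches import descriptors

    class Store:
        @descriptors.cached()
        def fn(self, arg1, arg2):
            ...

        @descriptors.cachedList("fn", "args1")
        async def list_fn(self, args1, arg2):
            # A batch lookup returning a dict keyed by the list argument.
            return {a: a for a in args1}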
111111 self.assertTrue(string_file.closed)
112112
113113
114 class DummyPullProducer(object):
114 class DummyPullProducer:
115115 def __init__(self):
116116 self.consumer = None
117117 self.deferred = defer.Deferred()
133133 return d
134134
135135
136 class BlockingStringWrite(object):
136 class BlockingStringWrite:
137137 def __init__(self):
138138 self.buffer = ""
139139 self.closed = False
9090 #
9191 # one more go, with success
9292 #
93 self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
93 self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
9494 limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
9595
9696 self.pump(1)
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from twisted.internet import defer
1516
1617 from synapse.util.async_helpers import ReadWriteLock
1718
4243 rwlock.read(key), # 5
4344 rwlock.write(key), # 6
4445 ]
46 ds = [defer.ensureDeferred(d) for d in ds]
4547
4648 self._assert_called_before_not_after(ds, 2)
4749
7274 with ds[6].result:
7375 pass
7476
75 d = rwlock.write(key)
77 d = defer.ensureDeferred(rwlock.write(key))
7678 self.assertTrue(d.called)
7779 with d.result:
7880 pass
7981
80 d = rwlock.read(key)
82 d = defer.ensureDeferred(rwlock.read(key))
8183 self.assertTrue(d.called)
8284 with d.result:
8385 pass
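The `ReadWriteLock` changes follow suit: `read` and `write` now return coroutines (resolving to context managers), so the test wraps them with `ensureDeferred` to get Deferreds it can inspect synchronously. The usage pattern, roughly:

    d = defer.ensureDeferred(rwlock.write(key))
    # Once the Deferred fires, its result is a context manager guarding
    # the critical section.
    with d.result:
        pass  # exclusive access held here; released on exit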
2727 "_--something==_",
2828 "...--==-18913",
2929 "8Dj2odd-e9asd.cd==_--ddas-secret-",
30 # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
31 # To be removed in a future release
32 "SECRET:1234567890",
3330 ]
3431
3532 bad = [
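For context on the removed fixture: the `:`-containing secret was only ever accepted as a temporary exemption (issue #6766), and these fixtures exercise a character-whitelist check. A rough sketch of such a check, assuming a simple regex (the real validation lives in the config code and may differ):

    import re

    ALLOWED = re.compile(r"^[0-9a-zA-Z_\-=.]+$")  # assumed pattern
    assert ALLOWED.match("8Dj2odd-e9asd.cd==_--ddas-secret-")
    assert not ALLOWED.match("SECRET:1234567890")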
471471 self.callbacks.append((method, path_pattern, callback))
472472
473473
474 class MockKey(object):
474 class MockKey:
475475 alg = "mock_alg"
476476 version = "mock_version"
477477 signature = b"\x9a\x87$"
490490 return b"<fake_encoded_key>"
491491
492492
493 class MockClock(object):
493 class MockClock:
494494 now = 1000
495495
496496 def __init__(self):
567567 )
568568
569569
570 class DeferredMockCallable(object):
570 class DeferredMockCallable:
571571 """A callable instance that stores a set of pending call expectations and
572572 return values for them. It allows a unit test to assert that the given set
573573 of function calls are eventually made, by awaiting on them to be called.
168168 skip_install = True
169169 deps =
170170 {[base]deps}
171 mypy==0.750
171 mypy==0.782
172172 mypy-zope
173 env =
174 MYPYPATH = stubs/
175173 extras = all
176 commands = mypy \
177 synapse/api \
178 synapse/appservice \
179 synapse/config \
180 synapse/event_auth.py \
181 synapse/events/builder.py \
182 synapse/events/spamcheck.py \
183 synapse/federation \
184 synapse/handlers/auth.py \
185 synapse/handlers/cas_handler.py \
186 synapse/handlers/directory.py \
187 synapse/handlers/federation.py \
188 synapse/handlers/identity.py \
189 synapse/handlers/message.py \
190 synapse/handlers/oidc_handler.py \
191 synapse/handlers/presence.py \
192 synapse/handlers/room_member.py \
193 synapse/handlers/room_member_worker.py \
194 synapse/handlers/saml_handler.py \
195 synapse/handlers/sync.py \
196 synapse/handlers/ui_auth \
197 synapse/http/server.py \
198 synapse/http/site.py \
199 synapse/logging/ \
200 synapse/metrics \
201 synapse/module_api \
202 synapse/notifier.py \
203 synapse/push/pusherpool.py \
204 synapse/push/push_rule_evaluator.py \
205 synapse/replication \
206 synapse/rest \
207 synapse/server.py \
208 synapse/server_notices \
209 synapse/spam_checker_api \
210 synapse/storage/databases/main/ui_auth.py \
211 synapse/storage/database.py \
212 synapse/storage/engines \
213 synapse/storage/state.py \
214 synapse/storage/util \
215 synapse/streams \
216 synapse/types.py \
217 synapse/util/caches/stream_change_cache.py \
218 synapse/util/metrics.py \
219 tests/replication \
220 tests/test_utils \
221 tests/rest/client/v2_alpha/test_auth.py \
222 tests/util/test_stream_change_cache.py
174 commands = mypy
223175
224176 # To find all folders that pass mypy you run:
225177 #