New upstream version 1.26.0
Andrej Shadura
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | 16 | import logging |
17 | ||
17 | 18 | from synapse.storage.engines import create_engine |
18 | 19 | |
19 | 20 | logger = logging.getLogger("create_postgres_db") |
11 | 11 | _trial_temp/ |
12 | 12 | _trial_temp*/ |
13 | 13 | /out |
14 | .DS_Store | |
14 | 15 | |
15 | 16 | # stuff that is likely to exist when you run a server locally |
16 | 17 | /*.db |
17 | 18 | /*.log |
19 | /*.log.* | |
18 | 20 | /*.log.config |
19 | 21 | /*.pid |
20 | 22 | /.python-version |
0 | Synapse 1.26.0 (2021-01-27) | |
1 | =========================== | |
2 | ||
3 | This release brings a new schema version for Synapse and rolling back to a previous | |
4 | version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details | |
5 | on these changes and for general upgrade guidance. | |
6 | ||
7 | No significant changes since 1.26.0rc2. | |
8 | ||
9 | ||
10 | Synapse 1.26.0rc2 (2021-01-25) | |
11 | ============================== | |
12 | ||
13 | Bugfixes | |
14 | -------- | |
15 | ||
16 | - Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195)) | |
17 | - Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210)) | |
18 | ||
19 | ||
20 | Internal Changes | |
21 | ---------------- | |
22 | ||
23 | - Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189)) | |
24 | - Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204)) | |
25 | ||
26 | ||
27 | Synapse 1.26.0rc1 (2021-01-20) | |
28 | ============================== | |
29 | ||
30 | This release brings a new schema version for Synapse and rolling back to a previous | |
31 | version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details | |
32 | on these changes and for general upgrade guidance. | |
33 | ||
34 | Features | |
35 | -------- | |
36 | ||
37 | - Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177)) | |
38 | - During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091)) | |
39 | - Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159)) | |
40 | - Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024)) | |
41 | - Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984)) | |
42 | - Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086)) | |
43 | - Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932)) | |
44 | - Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948)) | |
45 | - Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130)) | |
46 | - Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068)) | |
47 | - Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092)) | |
48 | - Add experimental support for moving receipts and account data persistence off the master process. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166)) | |
49 | ||
50 | ||
51 | Bugfixes | |
52 | -------- | |
53 | ||
54 | - Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023)) | |
55 | - Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028)) | |
56 | - Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051)) | |
57 | - Fix bug where we didn't correctly record CPU time spent in `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053)) | |
58 | - Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054)) | |
59 | - Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059)) | |
60 | - Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070)) | |
61 | - Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071)) | |
62 | - Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116)) | |
63 | - Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117)) | |
64 | - Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128)) | |
65 | - Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108)) | |
66 | - Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145)) | |
67 | - Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161)) | |
68 | ||
69 | ||
70 | Improved Documentation | |
71 | ---------------------- | |
72 | ||
73 | - Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997)) | |
74 | - Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035)) | |
75 | - Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040)) | |
76 | - Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057)) | |
77 | - Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151)) | |
78 | ||
79 | ||
80 | Deprecations and Removals | |
81 | ------------------------- | |
82 | ||
83 | - Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039)) | |
84 | ||
85 | ||
86 | Internal Changes | |
87 | ---------------- | |
88 | ||
89 | - Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124)) | |
90 | - Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939)) | |
91 | - Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016)) | |
92 | - Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018)) | |
93 | - Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025)) | |
94 | - Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030)) | |
95 | - Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031)) | |
96 | - Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033)) | |
97 | - Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038)) | |
98 | - Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041)) | |
99 | - Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055)) | |
100 | - Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058)) | |
101 | - Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063)) | |
102 | - Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069)) | |
103 | - Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080)) | |
104 | - Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093)) | |
105 | - Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098)) | |
106 | - Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106)) | |
107 | - Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112)) | |
108 | - Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125)) | |
109 | - Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144)) | |
110 | - Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146)) | |
111 | - Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157)) | |
112 | ||
113 | ||
0 | 114 | Synapse 1.25.0 (2021-01-13) |
1 | 115 | =========================== |
2 | 116 |
189 | 189 | |
190 | 190 | ```sh |
191 | 191 | brew install openssl@1.1 |
192 | export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/ | |
192 | export LDFLAGS="-L/usr/local/opt/openssl/lib" | |
193 | export CPPFLAGS="-I/usr/local/opt/openssl/include" | |
193 | 194 | ``` |
194 | 195 | |
195 | 196 | ##### OpenSUSE |
256 | 257 | |
257 | 258 | #### Docker images and Ansible playbooks |
258 | 259 | |
259 | There is an offical synapse image available at | |
260 | There is an official synapse image available at | |
260 | 261 | <https://hub.docker.com/r/matrixdotorg/synapse> which can be used with |
261 | 262 | the docker-compose file available at [contrib/docker](contrib/docker). Further |
262 | 263 | information on this including configuration options is available in the README |
242 | 242 | Synapse Development |
243 | 243 | =================== |
244 | 244 | |
245 | Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org) | |
245 | Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_ | |
246 | 246 | |
247 | 247 | Before setting up a development environment for synapse, make sure you have the |
248 | 248 | system dependencies (such as the python header files) installed - see |
278 | 278 | Ran 1337 tests in 716.064s |
279 | 279 | |
280 | 280 | PASSED (skips=15, successes=1322) |
281 | ||
282 | We recommend using the demo, which starts 3 federated instances running on ports `8080` - `8082`. | |
283 | ||
284 | ./demo/start.sh | |
285 | ||
286 | (to stop, you can use `./demo/stop.sh`) | |
287 | ||
288 | If you just want to start a single instance of the app and run it directly:: | |
289 | ||
290 | # Create the homeserver.yaml config once | |
291 | python -m synapse.app.homeserver \ | |
292 | --server-name my.domain.name \ | |
293 | --config-path homeserver.yaml \ | |
294 | --generate-config \ | |
295 | --report-stats=[yes|no] | |
296 | ||
297 | # Start the app | |
298 | python -m synapse.app.homeserver --config-path homeserver.yaml | |
299 | ||
300 | ||
301 | ||
281 | 302 | |
282 | 303 | Running the Integration Tests |
283 | 304 | ============================= |
83 | 83 | # replace `1.3.0` and `stretch` accordingly: |
84 | 84 | wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb |
85 | 85 | dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb |
86 | ||
87 | Upgrading to v1.26.0 | |
88 | ==================== | |
89 | ||
90 | Rolling back to v1.25.0 after a failed upgrade | |
91 | ---------------------------------------------- | |
92 | ||
93 | v1.26.0 includes a lot of large changes. If something problematic occurs, you | |
94 | may want to roll back to a previous version of Synapse. Because v1.26.0 also | |
95 | includes a new database schema version, reverting that version is also required | |
96 | alongside the generic rollback instructions mentioned above. In short, to roll | |
97 | back to v1.25.0 you need to: | |
98 | ||
99 | 1. Stop the server | |
100 | 2. Decrease the schema version in the database: | |
101 | ||
102 | .. code:: sql | |
103 | ||
104 | UPDATE schema_version SET version = 58; | |
105 | ||
106 | 3. Delete the ignored users & chain cover data: | |
107 | ||
108 | .. code:: sql | |
109 | ||
110 | DROP TABLE IF EXISTS ignored_users; | |
111 | UPDATE rooms SET has_auth_chain_index = false; | |
112 | ||
113 | For PostgreSQL run: | |
114 | ||
115 | .. code:: sql | |
116 | ||
117 | TRUNCATE event_auth_chain_links; | |
118 | TRUNCATE event_auth_chains; | |
119 | ||
120 | For SQLite run: | |
121 | ||
122 | .. code:: sql | |
123 | ||
124 | DELETE FROM event_auth_chain_links; | |
125 | DELETE FROM event_auth_chains; | |
126 | ||
127 | 4. Mark the deltas as not run (so they will re-run on upgrade). | |
128 | ||
129 | .. code:: sql | |
130 | ||
131 | DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/01ignored_user.py'; | |
132 | DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/06chain_cover_index.sql'; | |
133 | ||
134 | 5. Downgrade Synapse by following the instructions for your installation method | |
135 | in the "Rolling back to older versions" section above. | |
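On PostgreSQL deployments, steps 2-4 can also be scripted. A minimal sketch using ``psycopg2`` (the connection string is a placeholder; adjust it for your deployment, and take a backup first):

.. code:: python

    import psycopg2  # Synapse requires psycopg2 >= 2.8 as of this release

    STATEMENTS = [
        "UPDATE schema_version SET version = 58",
        "DROP TABLE IF EXISTS ignored_users",
        "UPDATE rooms SET has_auth_chain_index = false",
        "TRUNCATE event_auth_chain_links",
        "TRUNCATE event_auth_chains",
        "DELETE FROM applied_schema_deltas"
        " WHERE version = 59 AND file = '59/01ignored_user.py'",
        "DELETE FROM applied_schema_deltas"
        " WHERE version = 59 AND file = '59/06chain_cover_index.sql'",
    ]

    # 'with conn' commits the transaction on successful exit
    with psycopg2.connect("dbname=synapse user=synapse") as conn:  # placeholder DSN
        with conn.cursor() as cur:
            for stmt in STATEMENTS:
                cur.execute(stmt)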
86 | 136 | |
87 | 137 | Upgrading to v1.25.0 |
88 | 138 | ==================== |
0 | matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium | |
1 | ||
2 | * Remove dependency on `python3-distutils`. | |
3 | ||
4 | -- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000 | |
5 | ||
0 | 6 | matrix-synapse-py3 (1.25.0) stable; urgency=medium |
1 | 7 | |
2 | 8 | [ Dan Callahan ] |
30 | 30 | Depends: |
31 | 31 | adduser, |
32 | 32 | debconf, |
33 | python3-distutils|libpython3-stdlib (<< 3.6), | |
34 | 33 | ${misc:Depends}, |
35 | 34 | ${shlibs:Depends}, |
36 | 35 | ${synapse:pydepends}, |
0 | import argparse | |
1 | import BaseHTTPServer | |
2 | import os | |
3 | import SimpleHTTPServer | |
4 | import cgi, logging | |
5 | ||
6 | from daemonize import Daemonize | |
7 | ||
8 | ||
9 | class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler): | |
10 | UPLOAD_PATH = "upload" | |
11 | ||
12 | """ | |
13 | Accept all post request as file upload | |
14 | """ | |
15 | ||
16 | def do_POST(self): | |
17 | ||
18 | path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path)) | |
19 | length = self.headers["content-length"] | |
20 | data = self.rfile.read(int(length)) | |
21 | ||
22 | with open(path, "wb") as fh: | |
23 | fh.write(data) | |
24 | ||
25 | self.send_response(200) | |
26 | self.send_header("Content-Type", "application/json") | |
27 | self.end_headers() | |
28 | ||
29 | # Return the absolute path of the uploaded file | |
30 | self.wfile.write('{"url":"/%s"}' % path) | |
31 | ||
32 | ||
33 | def setup(): | |
34 | parser = argparse.ArgumentParser() | |
35 | parser.add_argument("directory") | |
36 | parser.add_argument("-p", "--port", dest="port", type=int, default=8080) | |
37 | parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid") | |
38 | args = parser.parse_args() | |
39 | ||
40 | # Get absolute path to directory to serve, as daemonize changes to '/' | |
41 | os.chdir(args.directory) | |
42 | dr = os.getcwd() | |
43 | ||
44 | httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST) | |
45 | ||
46 | def run(): | |
47 | os.chdir(dr) | |
48 | httpd.serve_forever() | |
49 | ||
50 | daemon = Daemonize( | |
51 | app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False | |
52 | ) | |
53 | ||
54 | daemon.start() | |
55 | ||
56 | ||
57 | if __name__ == "__main__": | |
58 | setup() |
197 | 197 | key_refresh_interval: "1d" # 1 Day. |
198 | 198 | |
199 | 199 | # The trusted servers to download signing keys from. |
200 | perspectives: | |
201 | servers: | |
202 | "matrix.org": | |
203 | verify_keys: | |
204 | "ed25519:auto": | |
205 | key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | |
200 | trusted_key_servers: | |
201 | - server_name: matrix.org | |
202 | verify_keys: | |
203 | "ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | |
206 | 204 | |
207 | 205 | password_config: |
208 | 206 | enabled: true |
3 | 3 | * [Quarantining media by ID](#quarantining-media-by-id) |
4 | 4 | * [Quarantining media in a room](#quarantining-media-in-a-room) |
5 | 5 | * [Quarantining all media of a user](#quarantining-all-media-of-a-user) |
6 | * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) | |
6 | 7 | - [Delete local media](#delete-local-media) |
7 | 8 | * [Delete a specific local media](#delete-a-specific-local-media) |
8 | 9 | * [Delete local media by date or size](#delete-local-media-by-date-or-size) |
121 | 122 | The following fields are returned in the JSON response body: |
122 | 123 | |
123 | 124 | * `num_quarantined`: integer - The number of media items successfully quarantined |
125 | ||
126 | ## Protecting media from being quarantined | |
127 | ||
128 | This API protects a single piece of local media from being quarantined using the | |
129 | above APIs. This is useful for sticker packs and other shared media which you do | |
130 | not want to get quarantined, especially when | |
131 | [quarantining media in a room](#quarantining-media-in-a-room). | |
132 | ||
133 | Request: | |
134 | ||
135 | ``` | |
136 | POST /_synapse/admin/v1/media/protect/<media_id> | |
137 | ||
138 | {} | |
139 | ``` | |
140 | ||
141 | Where `media_id` is in the form of `abcdefg12345...`. | |
142 | ||
143 | Response: | |
144 | ||
145 | ```json | |
146 | {} | |
147 | ``` | |
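For instance, a minimal sketch of calling this API with an admin access token (the homeserver URL, token and media ID below are placeholders; any HTTP client works):

```python
import requests  # assumed HTTP client for this sketch

resp = requests.post(
    "https://synapse.example.com/_synapse/admin/v1/media/protect/abcdefg12345",
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={},  # the API takes an empty JSON body
)
resp.raise_for_status()  # the response body is an empty JSON object
```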
124 | 148 | |
125 | 149 | # Delete local media |
126 | 150 | This API deletes the *local* media from the disk of your own server. |
97 | 97 | |
98 | 98 | - ``deactivated``, optional. If unspecified, deactivation state will be left |
99 | 99 | unchanged on existing accounts and set to ``false`` for new accounts. |
100 | A user cannot be erased by deactivating with this API. For details on deactivating users see | |
101 | `Deactivate Account <#deactivate-account>`_. | |
100 | 102 | |
101 | 103 | If the user already exists then optional parameters default to the current value. |
102 | 104 | |
247 | 249 | The erase parameter is optional and defaults to ``false``. |
248 | 250 | An empty body may be passed for backwards compatibility. |
249 | 251 | |
252 | The following actions are performed when deactivating a user: | |
253 | ||
254 | - Try to unbind 3PIDs from the identity server | |
255 | - Remove all 3PIDs from the homeserver | |
256 | - Delete all devices and E2EE keys | |
257 | - Delete all access tokens | |
258 | - Delete the password hash | |
259 | - Remove the user from all rooms they are a member of | |
260 | - Remove the user from the user directory | |
261 | - Reject all pending invites | |
262 | - Remove all account validity information related to the user | |
263 | ||
264 | The following additional actions are performed during deactivation if ``erase`` | |
265 | is set to ``true``: | |
266 | ||
267 | - Remove the user's display name | |
268 | - Remove the user's avatar URL | |
269 | - Mark the user as erased | |
270 | ||
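A minimal sketch of calling the deactivation endpoint with the ``requests``
library (the homeserver URL, user ID and admin access token below are
placeholders):

.. code:: python

    import requests  # assumed HTTP client for this sketch

    resp = requests.post(
        "https://synapse.example.com/_synapse/admin/v1"
        "/deactivate/@user:example.com",  # user ID may need percent-encoding
        headers={"Authorization": "Bearer <admin_access_token>"},
        json={"erase": True},  # also perform the "erase" actions listed above
    )
    resp.raise_for_status()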
250 | 271 | |
251 | 272 | Reset password |
252 | 273 | ============== |
335 | 356 | ], |
336 | 357 | "total": 2 |
337 | 358 | } |
359 | ||
360 | The server returns the list of rooms of which both the user and the server | |
361 | are members. If the user is local, all the rooms of which the user is | |
362 | a member are returned. | |
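A minimal sketch of querying this endpoint (the homeserver URL, user ID and
admin access token below are placeholders):

.. code:: python

    import requests  # assumed HTTP client for this sketch

    resp = requests.get(
        "https://synapse.example.com/_synapse/admin/v1"
        "/users/@user:example.com/joined_rooms",
        headers={"Authorization": "Bearer <admin_access_token>"},
    )
    print(resp.json()["total"])  # number of rooms in "joined_rooms"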
338 | 363 | |
339 | 364 | **Parameters** |
340 | 365 |
0 | digraph auth { | |
1 | nodesep=0.5; | |
2 | rankdir="RL"; | |
3 | ||
4 | C [label="Create (1,1)"]; | |
5 | ||
6 | BJ [label="Bob's Join (2,1)", color=red]; | |
7 | BJ2 [label="Bob's Join (2,2)", color=red]; | |
8 | BJ2 -> BJ [color=red, dir=none]; | |
9 | ||
10 | subgraph cluster_foo { | |
11 | A1 [label="Alice's invite (4,1)", color=blue]; | |
12 | A2 [label="Alice's Join (4,2)", color=blue]; | |
13 | A3 [label="Alice's Join (4,3)", color=blue]; | |
14 | A3 -> A2 -> A1 [color=blue, dir=none]; | |
15 | color=none; | |
16 | } | |
17 | ||
18 | PL1 [label="Power Level (3,1)", color=darkgreen]; | |
19 | PL2 [label="Power Level (3,2)", color=darkgreen]; | |
20 | PL2 -> PL1 [color=darkgreen, dir=none]; | |
21 | ||
22 | {rank = same; C; BJ; PL1; A1;} | |
23 | ||
24 | A1 -> C [color=grey]; | |
25 | A1 -> BJ [color=grey]; | |
26 | PL1 -> C [color=grey]; | |
27 | BJ2 -> PL1 [penwidth=2]; | |
28 | ||
29 | A3 -> PL2 [penwidth=2]; | |
30 | A1 -> PL1 -> BJ -> C [penwidth=2]; | |
31 | } |
Binary diff not shown
0 | # Auth Chain Difference Algorithm | |
1 | ||
2 | The auth chain difference algorithm is used by V2 state resolution, where a | |
3 | naive implementation can be a significant source of CPU and DB usage. | |
4 | ||
5 | ### Definitions | |
6 | ||
7 | A *state set* is a set of state events; e.g. the input of a state resolution | |
8 | algorithm is a collection of state sets. | |
9 | ||
10 | The *auth chain* of a set of events is all the events' auth events and *their* | |
11 | auth events, recursively (i.e. the events reachable by walking the graph induced | |
12 | by the events' auth-event links). | |
13 | ||
14 | The *auth chain difference* of a collection of state sets is the union minus the | |
15 | intersection of the sets of auth chains corresponding to the state sets, i.e. an | |
16 | event is in the auth chain difference if it is reachable by walking the auth | |
17 | event graph from at least one of the state sets but not from *all* of the state | |
18 | sets. | |
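For concreteness, a naive implementation of these definitions might look like
the following sketch, where events are given as a hypothetical map
`event_id -> list of auth event IDs`:

```python
from typing import Dict, List, Set

def auth_chain(events: Dict[str, List[str]], event_ids: Set[str]) -> Set[str]:
    """All events reachable by walking auth-event links from event_ids."""
    chain: Set[str] = set()
    stack = list(event_ids)
    while stack:
        for auth_id in events[stack.pop()]:
            if auth_id not in chain:
                chain.add(auth_id)
                stack.append(auth_id)
    return chain

def naive_auth_chain_difference(
    events: Dict[str, List[str]], state_sets: List[Set[str]]
) -> Set[str]:
    """Union minus intersection of the state sets' auth chains."""
    chains = [auth_chain(events, s) for s in state_sets]
    return set.union(*chains) - set.intersection(*chains)
```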
19 | ||
20 | ## Breadth First Walk Algorithm | |
21 | ||
22 | A way of calculating the auth chain difference without calculating the full auth | |
23 | chains for each state set is to do a parallel breadth first walk (ordered by | |
24 | depth) of each state set's auth chain. By tracking which events are reachable | |
25 | from each state set we can finish early if every pending event is reachable from | |
26 | every state set. | |
27 | ||
28 | This can work well for state sets that have a small auth chain difference, but | |
29 | can be very inefficient for larger differences. However, this algorithm is still | |
30 | used if we don't have a chain cover index for the room (e.g. because we're in | |
31 | the process of indexing it). | |
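A sketch of this walk, under the assumption that auth-event links always point
to events of strictly smaller depth (events are given as a hypothetical map
`event_id -> (depth, list of auth event IDs)`):

```python
import heapq
from typing import Dict, List, Set, Tuple

def bfs_auth_chain_difference(
    events: Dict[str, Tuple[int, List[str]]],  # event_id -> (depth, auth_ids)
    state_sets: List[Set[str]],
) -> Set[str]:
    full = (1 << len(state_sets)) - 1
    reachable: Dict[str, int] = {}  # event_id -> bitmask of state sets
    heap: List[Tuple[int, str]] = []  # (-depth, event_id): deepest first

    def add(event_id: str, mask: int) -> None:
        if event_id not in reachable:
            reachable[event_id] = 0
            heapq.heappush(heap, (-events[event_id][0], event_id))
        reachable[event_id] |= mask

    # Seed the walk with the auth events of each state set.
    for i, state_set in enumerate(state_sets):
        for event_id in state_set:
            for auth_id in events[event_id][1]:
                add(auth_id, 1 << i)

    difference: Set[str] = set()
    while heap:
        # Finish early once every pending event is reachable from all
        # state sets (a simple but quadratic check, fine for a sketch).
        if all(reachable[eid] == full for _, eid in heap):
            break
        _, event_id = heapq.heappop(heap)
        mask = reachable[event_id]  # final: deeper referrers already processed
        if mask != full:
            difference.add(event_id)
        for auth_id in events[event_id][1]:
            add(auth_id, mask)
    return difference
```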
32 | ||
33 | ## Chain Cover Index | |
34 | ||
35 | Synapse computes auth chain differences by pre-computing a "chain cover" index | |
36 | for the auth chain in a room, allowing efficient reachability queries like "is | |
37 | event A in the auth chain of event B". This is done by assigning every event a | |
38 | *chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links* | |
39 | between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A` | |
40 | is in the auth chain of `B`) if and only if either: | |
41 | ||
42 | 1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s | |
43 | sequence number; or | |
44 | 2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that | |
45 | `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`. | |
46 | ||
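In code, the reachability test might look like this sketch, assuming links are
stored in a hypothetical map keyed by `(B's chain ID, A's chain ID)`:

```python
from typing import Dict, List, Tuple

Chain = Tuple[int, int]  # (chain ID, sequence number)

def is_in_auth_chain(
    a: Chain, b: Chain, links: Dict[Tuple[int, int], List[Tuple[int, int]]]
) -> bool:
    """True if event A is in the auth chain of event B."""
    (chain_a, seq_a), (chain_b, seq_b) = a, b
    if chain_a == chain_b:
        return seq_a < seq_b  # condition 1
    # Condition 2: a link from B's chain to A's chain covering both events.
    return any(
        start <= seq_b and seq_a <= end
        for start, end in links.get((chain_b, chain_a), [])
    )
```
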
47 | There are actually two potential implementations, one where we store links from | |
48 | each chain to every other reachable chain (the transitive closure of the links | |
49 | graph), and one where we remove redundant links (the transitive reduction of the | |
50 | links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1` | |
51 | would not be stored. Synapse uses the former implementation so that it doesn't | |
52 | need to recurse to test reachability between chains. | |
53 | ||
54 | ### Example | |
55 | ||
56 | An example auth graph would look like the following, where chains have been | |
57 | formed based on type/state_key, are denoted by colour, and are labelled with | |
58 | `(chain ID, sequence number)`. Links are denoted by the arrows (links in grey | |
59 | are those that would be removed in the second implementation described above). | |
60 | ||
61 | ![Example](auth_chain_diff.dot.png) | |
62 | ||
63 | Note that we don't include all links between events and their auth events, as | |
64 | most of those links would be redundant. For example, all events point to the | |
65 | create event, but each chain only needs the one link from its base to the | |
66 | create event. | |
67 | ||
68 | ## Using the Index | |
69 | ||
70 | This index can be used to calculate the auth chain difference of the state sets | |
71 | by looking at the chain ID and sequence numbers reachable from each state set: | |
72 | ||
73 | 1. For every state set, look up the chain ID/sequence numbers of each state event | |
74 | 2. Use the index to find all chains and the maximum sequence number reachable | |
75 | from each state set. | |
76 | 3. The auth chain difference is then all events in each chain that have sequence | |
77 | numbers between the maximum sequence number reachable from *any* state set and | |
78 | the minimum reachable by *all* state sets (if any). | |
79 | ||
80 | Note that step 2 is effectively calculating the auth chain for each state set | |
81 | (in terms of chain IDs and sequence numbers), and step 3 is calculating the | |
82 | difference between the union and intersection of the auth chains. | |
83 | ||
84 | ### Worked Example | |
85 | ||
86 | For example, given the above graph, we can calculate the difference between | |
87 | state sets consisting of: | |
88 | ||
89 | 1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and | |
90 | 2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`. | |
91 | ||
92 | Using the index we see that the following auth chains are reachable from each | |
93 | state set: | |
94 | ||
95 | 1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)` | |
96 | 2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)` | |
97 | ||
98 | And so, for each the ranges that are in the auth chain difference: | |
99 | 1. Chain 1: None, (since everything can reach the create event). | |
100 | 2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state | |
101 | sets and the maximum reachable is `2` (corresponding to Bob's second join). | |
102 | 3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power | |
103 | level). | |
104 | 4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins). | |
105 | ||
106 | So the final result is: Bob's second join `(2,2)`, the second power level | |
107 | `(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`. |
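This calculation can be checked with a short sketch: represent what each state
set reaches as a map `chain ID -> maximum reachable sequence number` (step 2),
then take, per chain, the range between the minimum of those maxima (reachable
by *all* sets) and their maximum (reachable by *any* set):

```python
from typing import Dict, List, Tuple

def difference_ranges(
    reachable: List[Dict[int, int]]  # per state set: chain ID -> max seq no
) -> Dict[int, Tuple[int, int]]:
    ranges = {}
    for chain in set().union(*reachable):
        maxes = [r.get(chain, 0) for r in reachable]  # 0 = chain not reached
        lo, hi = min(maxes), max(maxes)
        if hi > lo:
            ranges[chain] = (lo, hi)  # events with lo < seq_no <= hi
    return ranges

s1 = {1: 1, 2: 2, 3: 1, 4: 1}  # S1 reaches (1,1), (2,2), (3,1), (4,1)
s2 = {1: 1, 2: 1, 3: 2, 4: 3}  # S2 reaches (1,1), (2,1), (3,2), (4,3)
print(difference_ranges([s1, s2]))
# {2: (1, 2), 3: (1, 2), 4: (1, 3)}: Bob's second join, the second
# power level, and Alice's joins (4,2) & (4,3), as above.
```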
41 | 41 | * For other installation mechanisms, see the documentation provided by the |
42 | 42 | maintainer. |
43 | 43 | |
44 | To enable the OpenID integration, you should then add an `oidc_config` section | |
45 | to your configuration file (or uncomment the `enabled: true` line in the | |
46 | existing section). See [sample_config.yaml](./sample_config.yaml) for some | |
47 | sample settings, as well as the text below for example configurations for | |
48 | specific providers. | |
44 | To enable the OpenID integration, you should then add a section to the `oidc_providers` | |
45 | setting in your configuration file (or uncomment one of the existing examples). | |
46 | See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as | |
47 | the text below for example configurations for specific providers. | |
49 | 48 | |
50 | 49 | ## Sample configs |
51 | 50 | |
61 | 60 | Edit your Synapse config file and change the `oidc_config` section: |
62 | 61 | |
63 | 62 | ```yaml |
64 | oidc_config: | |
65 | enabled: true | |
66 | issuer: "https://login.microsoftonline.com/<tenant id>/v2.0" | |
67 | client_id: "<client id>" | |
68 | client_secret: "<client secret>" | |
69 | scopes: ["openid", "profile"] | |
70 | authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize" | |
71 | token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token" | |
72 | userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" | |
73 | ||
74 | user_mapping_provider: | |
75 | config: | |
76 | localpart_template: "{{ user.preferred_username.split('@')[0] }}" | |
77 | display_name_template: "{{ user.name }}" | |
63 | oidc_providers: | |
64 | - idp_id: microsoft | |
65 | idp_name: Microsoft | |
66 | issuer: "https://login.microsoftonline.com/<tenant id>/v2.0" | |
67 | client_id: "<client id>" | |
68 | client_secret: "<client secret>" | |
69 | scopes: ["openid", "profile"] | |
70 | authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize" | |
71 | token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token" | |
72 | userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" | |
73 | ||
74 | user_mapping_provider: | |
75 | config: | |
76 | localpart_template: "{{ user.preferred_username.split('@')[0] }}" | |
77 | display_name_template: "{{ user.name }}" | |
78 | 78 | ``` |
79 | 79 | |
80 | 80 | ### [Dex][dex-idp] |
102 | 102 | Synapse config: |
103 | 103 | |
104 | 104 | ```yaml |
105 | oidc_config: | |
106 | enabled: true | |
107 | skip_verification: true # This is needed as Dex is served on an insecure endpoint | |
108 | issuer: "http://127.0.0.1:5556/dex" | |
109 | client_id: "synapse" | |
110 | client_secret: "secret" | |
111 | scopes: ["openid", "profile"] | |
112 | user_mapping_provider: | |
113 | config: | |
114 | localpart_template: "{{ user.name }}" | |
115 | display_name_template: "{{ user.name|capitalize }}" | |
105 | oidc_providers: | |
106 | - idp_id: dex | |
107 | idp_name: "My Dex server" | |
108 | skip_verification: true # This is needed as Dex is served on an insecure endpoint | |
109 | issuer: "http://127.0.0.1:5556/dex" | |
110 | client_id: "synapse" | |
111 | client_secret: "secret" | |
112 | scopes: ["openid", "profile"] | |
113 | user_mapping_provider: | |
114 | config: | |
115 | localpart_template: "{{ user.name }}" | |
116 | display_name_template: "{{ user.name|capitalize }}" | |
116 | 117 | ``` |
117 | 118 | ### [Keycloak][keycloak-idp] |
118 | 119 | |
151 | 152 | 8. Copy Secret |
152 | 153 | |
153 | 154 | ```yaml |
154 | oidc_config: | |
155 | enabled: true | |
156 | issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" | |
157 | client_id: "synapse" | |
158 | client_secret: "copy secret generated from above" | |
159 | scopes: ["openid", "profile"] | |
155 | oidc_providers: | |
156 | - idp_id: keycloak | |
157 | idp_name: "My KeyCloak server" | |
158 | issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" | |
159 | client_id: "synapse" | |
160 | client_secret: "copy secret generated from above" | |
161 | scopes: ["openid", "profile"] | |
162 | user_mapping_provider: | |
163 | config: | |
164 | localpart_template: "{{ user.preferred_username }}" | |
165 | display_name_template: "{{ user.name }}" | |
160 | 166 | ``` |
161 | 167 | ### [Auth0][auth0] |
162 | 168 | |
186 | 192 | Synapse config: |
187 | 193 | |
188 | 194 | ```yaml |
189 | oidc_config: | |
190 | enabled: true | |
191 | issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED | |
192 | client_id: "your-client-id" # TO BE FILLED | |
193 | client_secret: "your-client-secret" # TO BE FILLED | |
194 | scopes: ["openid", "profile"] | |
195 | user_mapping_provider: | |
196 | config: | |
197 | localpart_template: "{{ user.preferred_username }}" | |
198 | display_name_template: "{{ user.name }}" | |
195 | oidc_providers: | |
196 | - idp_id: auth0 | |
197 | idp_name: Auth0 | |
198 | issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED | |
199 | client_id: "your-client-id" # TO BE FILLED | |
200 | client_secret: "your-client-secret" # TO BE FILLED | |
201 | scopes: ["openid", "profile"] | |
202 | user_mapping_provider: | |
203 | config: | |
204 | localpart_template: "{{ user.preferred_username }}" | |
205 | display_name_template: "{{ user.name }}" | |
199 | 206 | ``` |
200 | 207 | |
201 | 208 | ### GitHub |
214 | 221 | Synapse config: |
215 | 222 | |
216 | 223 | ```yaml |
217 | oidc_config: | |
218 | enabled: true | |
219 | discover: false | |
220 | issuer: "https://github.com/" | |
221 | client_id: "your-client-id" # TO BE FILLED | |
222 | client_secret: "your-client-secret" # TO BE FILLED | |
223 | authorization_endpoint: "https://github.com/login/oauth/authorize" | |
224 | token_endpoint: "https://github.com/login/oauth/access_token" | |
225 | userinfo_endpoint: "https://api.github.com/user" | |
226 | scopes: ["read:user"] | |
227 | user_mapping_provider: | |
228 | config: | |
229 | subject_claim: "id" | |
230 | localpart_template: "{{ user.login }}" | |
231 | display_name_template: "{{ user.name }}" | |
224 | oidc_providers: | |
225 | - idp_id: github | |
226 | idp_name: Github | |
227 | discover: false | |
228 | issuer: "https://github.com/" | |
229 | client_id: "your-client-id" # TO BE FILLED | |
230 | client_secret: "your-client-secret" # TO BE FILLED | |
231 | authorization_endpoint: "https://github.com/login/oauth/authorize" | |
232 | token_endpoint: "https://github.com/login/oauth/access_token" | |
233 | userinfo_endpoint: "https://api.github.com/user" | |
234 | scopes: ["read:user"] | |
235 | user_mapping_provider: | |
236 | config: | |
237 | subject_claim: "id" | |
238 | localpart_template: "{{ user.login }}" | |
239 | display_name_template: "{{ user.name }}" | |
232 | 240 | ``` |
233 | 241 | |
234 | 242 | ### [Google][google-idp] |
238 | 246 | 2. add an "OAuth Client ID" for a Web Application under "Credentials". |
239 | 247 | 3. Copy the Client ID and Client Secret, and add the following to your synapse config: |
240 | 248 | ```yaml |
241 | oidc_config: | |
242 | enabled: true | |
243 | issuer: "https://accounts.google.com/" | |
244 | client_id: "your-client-id" # TO BE FILLED | |
245 | client_secret: "your-client-secret" # TO BE FILLED | |
246 | scopes: ["openid", "profile"] | |
247 | user_mapping_provider: | |
248 | config: | |
249 | localpart_template: "{{ user.given_name|lower }}" | |
250 | display_name_template: "{{ user.name }}" | |
249 | oidc_providers: | |
250 | - idp_id: google | |
251 | idp_name: Google | |
252 | issuer: "https://accounts.google.com/" | |
253 | client_id: "your-client-id" # TO BE FILLED | |
254 | client_secret: "your-client-secret" # TO BE FILLED | |
255 | scopes: ["openid", "profile"] | |
256 | user_mapping_provider: | |
257 | config: | |
258 | localpart_template: "{{ user.given_name|lower }}" | |
259 | display_name_template: "{{ user.name }}" | |
251 | 260 | ``` |
252 | 261 | 4. Back in the Google console, add this Authorized redirect URI: `[synapse |
253 | 262 | public baseurl]/_synapse/oidc/callback`. |
261 | 270 | Synapse config: |
262 | 271 | |
263 | 272 | ```yaml |
264 | oidc_config: | |
265 | enabled: true | |
266 | issuer: "https://id.twitch.tv/oauth2/" | |
267 | client_id: "your-client-id" # TO BE FILLED | |
268 | client_secret: "your-client-secret" # TO BE FILLED | |
269 | client_auth_method: "client_secret_post" | |
270 | user_mapping_provider: | |
271 | config: | |
272 | localpart_template: "{{ user.preferred_username }}" | |
273 | display_name_template: "{{ user.name }}" | |
273 | oidc_providers: | |
274 | - idp_id: twitch | |
275 | idp_name: Twitch | |
276 | issuer: "https://id.twitch.tv/oauth2/" | |
277 | client_id: "your-client-id" # TO BE FILLED | |
278 | client_secret: "your-client-secret" # TO BE FILLED | |
279 | client_auth_method: "client_secret_post" | |
280 | user_mapping_provider: | |
281 | config: | |
282 | localpart_template: "{{ user.preferred_username }}" | |
283 | display_name_template: "{{ user.name }}" | |
274 | 284 | ``` |
275 | 285 | |
276 | 286 | ### GitLab |
282 | 292 | Synapse config: |
283 | 293 | |
284 | 294 | ```yaml |
285 | oidc_config: | |
286 | enabled: true | |
287 | issuer: "https://gitlab.com/" | |
288 | client_id: "your-client-id" # TO BE FILLED | |
289 | client_secret: "your-client-secret" # TO BE FILLED | |
290 | client_auth_method: "client_secret_post" | |
291 | scopes: ["openid", "read_user"] | |
292 | user_profile_method: "userinfo_endpoint" | |
293 | user_mapping_provider: | |
294 | config: | |
295 | localpart_template: '{{ user.nickname }}' | |
296 | display_name_template: '{{ user.name }}' | |
297 | ``` | |
295 | oidc_providers: | |
296 | - idp_id: gitlab | |
297 | idp_name: Gitlab | |
298 | issuer: "https://gitlab.com/" | |
299 | client_id: "your-client-id" # TO BE FILLED | |
300 | client_secret: "your-client-secret" # TO BE FILLED | |
301 | client_auth_method: "client_secret_post" | |
302 | scopes: ["openid", "read_user"] | |
303 | user_profile_method: "userinfo_endpoint" | |
304 | user_mapping_provider: | |
305 | config: | |
306 | localpart_template: '{{ user.nickname }}' | |
307 | display_name_template: '{{ user.name }}' | |
308 | ``` |
17 | 17 | virtualenv](../INSTALL.md#installing-from-source), you can install |
18 | 18 | the library with: |
19 | 19 | |
20 | ~/synapse/env/bin/pip install matrix-synapse[postgres] | |
20 | ~/synapse/env/bin/pip install "matrix-synapse[postgres]" | |
21 | 21 | |
22 | 22 | (substituting the path to your virtualenv for `~/synapse/env`, if |
23 | 23 | you used a different path). You will require the postgres |
66 | 66 | # |
67 | 67 | #web_client_location: https://riot.example.com/ |
68 | 68 | |
69 | # The public-facing base URL that clients use to access this HS | |
70 | # (not including _matrix/...). This is the same URL a user would | |
71 | # enter into the 'custom HS URL' field on their client. If you | |
72 | # use synapse with a reverse proxy, this should be the URL to reach | |
73 | # synapse via the proxy. | |
69 | # The public-facing base URL that clients use to access this Homeserver (not | |
70 | # including _matrix/...). This is the same URL a user might enter into the | |
71 | # 'Custom Homeserver URL' field on their client. If you use Synapse with a | |
72 | # reverse proxy, this should be the URL to reach Synapse via the proxy. | |
73 | # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see | |
74 | # 'listeners' below). | |
75 | # | |
76 | # If this is left unset, it defaults to 'https://<server_name>/'. (Note that | |
77 | # that will not work unless you configure Synapse or a reverse-proxy to listen | |
78 | # on port 443.) | |
74 | 79 | # |
75 | 80 | #public_baseurl: https://example.com/ |
76 | 81 | |
1149 | 1154 | # send an email to the account's email address with a renewal link. By |
1150 | 1155 | # default, no such emails are sent. |
1151 | 1156 | # |
1152 | # If you enable this setting, you will also need to fill out the 'email' and | |
1153 | # 'public_baseurl' configuration sections. | |
1157 | # If you enable this setting, you will also need to fill out the 'email' | |
1158 | # configuration section. You should also check that 'public_baseurl' is set | |
1159 | # correctly. | |
1154 | 1160 | # |
1155 | 1161 | #renew_at: 1w |
1156 | 1162 | |
1241 | 1247 | # The identity server which we suggest that clients should use when users log |
1242 | 1248 | # in on this server. |
1243 | 1249 | # |
1244 | # (By default, no suggestion is made, so it is left up to the client. | |
1245 | # This setting is ignored unless public_baseurl is also set.) | |
1250 | # (By default, no suggestion is made, so it is left up to the client.) | |
1246 | 1251 | # |
1247 | 1252 | #default_identity_server: https://matrix.org |
1248 | 1253 | |
1266 | 1271 | # Servers handling these requests must answer the `/requestToken` endpoints defined
1267 | 1272 | # by the Matrix Identity Service API specification: |
1268 | 1273 | # https://matrix.org/docs/spec/identity_service/latest |
1269 | # | |
1270 | # If a delegate is specified, the config option public_baseurl must also be filled out. | |
1271 | 1274 | # |
1272 | 1275 | account_threepid_delegates: |
1273 | 1276 | #email: https://example.com # Delegate email sending to example.com |
1708 | 1711 | #idp_entityid: 'https://our_idp/entityid' |
1709 | 1712 | |
1710 | 1713 | |
1711 | # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. | |
1714 | # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration | |
1715 | # and login. | |
1716 | # | |
1717 | # Options for each entry include: | |
1718 | # | |
1719 | # idp_id: a unique identifier for this identity provider. Used internally | |
1720 | # by Synapse; should be a single word such as 'github'. | |
1721 | # | |
1722 | # Note that, if this is changed, users authenticating via that provider | |
1723 | # will no longer be recognised as the same user! | |
1724 | # | |
1725 | # idp_name: A user-facing name for this identity provider, which is used to | |
1726 | # offer the user a choice of login mechanisms. | |
1727 | # | |
1728 | # idp_icon: An optional icon for this identity provider, which is presented | |
1729 | # by identity picker pages. If given, must be an MXC URI of the format | |
1730 | # mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI | |
1731 | # is to upload an image to an (unencrypted) room and then copy the "url" | |
1732 | # from the source of the event.) | |
1733 | # | |
1734 | # discover: set to 'false' to disable the use of the OIDC discovery mechanism | |
1735 | # to discover endpoints. Defaults to true. | |
1736 | # | |
1737 | # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery | |
1738 | # is enabled) to discover the provider's endpoints. | |
1739 | # | |
1740 | # client_id: Required. oauth2 client id to use. | |
1741 | # | |
1742 | # client_secret: Required. oauth2 client secret to use. | |
1743 | # | |
1744 | # client_auth_method: auth method to use when exchanging the token. Valid | |
1745 | # values are 'client_secret_basic' (default), 'client_secret_post' and | |
1746 | # 'none'. | |
1747 | # | |
1748 | # scopes: list of scopes to request. This should normally include the "openid" | |
1749 | # scope. Defaults to ["openid"]. | |
1750 | # | |
1751 | # authorization_endpoint: the oauth2 authorization endpoint. Required if | |
1752 | # provider discovery is disabled. | |
1753 | # | |
1754 | # token_endpoint: the oauth2 token endpoint. Required if provider discovery is | |
1755 | # disabled. | |
1756 | # | |
1757 | # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is | |
1758 | # disabled and the 'openid' scope is not requested. | |
1759 | # | |
1760 | # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and | |
1761 | # the 'openid' scope is used. | |
1762 | # | |
1763 | # skip_verification: set to 'true' to skip metadata verification. Use this if | |
1764 | # you are connecting to a provider that is not OpenID Connect compliant. | |
1765 | # Defaults to false. Avoid this in production. | |
1766 | # | |
1767 | # user_profile_method: Whether to fetch the user profile from the userinfo | |
1768 | # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. | |
1769 | # | |
1770 | # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is | |
1771 | # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the | |
1772 | # userinfo endpoint. | |
1773 | # | |
1774 | # allow_existing_users: set to 'true' to allow a user logging in via OIDC to | |
1775 | # match a pre-existing account instead of failing. This could be used if | |
1776 | # switching from password logins to OIDC. Defaults to false. | |
1777 | # | |
1778 | # user_mapping_provider: Configuration for how attributes returned from an OIDC | |
1779 | # provider are mapped onto a matrix user. This setting has the following | |
1780 | # sub-properties: | |
1781 | # | |
1782 | # module: The class name of a custom mapping module. Default is | |
1783 | # 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. | |
1784 | # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers | |
1785 | # for information on implementing a custom mapping provider. | |
1786 | # | |
1787 | # config: Configuration for the mapping provider module. This section will | |
1788 | # be passed as a Python dictionary to the user mapping provider | |
1789 | # module's `parse_config` method. | |
1790 | # | |
1791 | # For the default provider, the following settings are available: | |
1792 | # | |
1793 | # sub: name of the claim containing a unique identifier for the | |
1794 | # user. Defaults to 'sub', which OpenID Connect compliant | |
1795 | # providers should provide. | |
1796 | # | |
1797 | # localpart_template: Jinja2 template for the localpart of the MXID. | |
1798 | # If this is not set, the user will be prompted to choose their | |
1799 | # own username. | |
1800 | # | |
1801 | # display_name_template: Jinja2 template for the display name to set | |
1802 | # on first login. If unset, no displayname will be set. | |
1803 | # | |
1804 | # extra_attributes: a map of Jinja2 templates for extra attributes | |
1805 | # to send back to the client during login. | |
1806 | # Note that these are non-standard and clients will ignore them | |
1807 | # without modifications. | |
1808 | # | |
1809 | # When rendering, the Jinja2 templates are given a 'user' variable, | |
1810 | # which is set to the claims returned by the UserInfo Endpoint and/or | |
1811 | # in the ID Token. | |
1712 | 1812 | # |
1713 | 1813 | # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md |
1714 | # for some example configurations. | |
1715 | # | |
1716 | oidc_config: | |
1717 | # Uncomment the following to enable authorization against an OpenID Connect | |
1718 | # server. Defaults to false. | |
1719 | # | |
1720 | #enabled: true | |
1721 | ||
1722 | # Uncomment the following to disable use of the OIDC discovery mechanism to | |
1723 | # discover endpoints. Defaults to true. | |
1724 | # | |
1725 | #discover: false | |
1726 | ||
1727 | # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to | |
1728 | # discover the provider's endpoints. | |
1729 | # | |
1730 | # Required if 'enabled' is true. | |
1731 | # | |
1732 | #issuer: "https://accounts.example.com/" | |
1733 | ||
1734 | # oauth2 client id to use. | |
1735 | # | |
1736 | # Required if 'enabled' is true. | |
1737 | # | |
1738 | #client_id: "provided-by-your-issuer" | |
1739 | ||
1740 | # oauth2 client secret to use. | |
1741 | # | |
1742 | # Required if 'enabled' is true. | |
1743 | # | |
1744 | #client_secret: "provided-by-your-issuer" | |
1745 | ||
1746 | # auth method to use when exchanging the token. | |
1747 | # Valid values are 'client_secret_basic' (default), 'client_secret_post' and | |
1748 | # 'none'. | |
1749 | # | |
1750 | #client_auth_method: client_secret_post | |
1751 | ||
1752 | # list of scopes to request. This should normally include the "openid" scope. | |
1753 | # Defaults to ["openid"]. | |
1754 | # | |
1755 | #scopes: ["openid", "profile"] | |
1756 | ||
1757 | # the oauth2 authorization endpoint. Required if provider discovery is disabled. | |
1758 | # | |
1759 | #authorization_endpoint: "https://accounts.example.com/oauth2/auth" | |
1760 | ||
1761 | # the oauth2 token endpoint. Required if provider discovery is disabled. | |
1762 | # | |
1763 | #token_endpoint: "https://accounts.example.com/oauth2/token" | |
1764 | ||
1765 | # the OIDC userinfo endpoint. Required if discovery is disabled and the | |
1766 | # "openid" scope is not requested. | |
1767 | # | |
1768 | #userinfo_endpoint: "https://accounts.example.com/userinfo" | |
1769 | ||
1770 | # URI where to fetch the JWKS. Required if discovery is disabled and the | |
1771 | # "openid" scope is used. | |
1772 | # | |
1773 | #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" | |
1774 | ||
1775 | # Uncomment to skip metadata verification. Defaults to false. | |
1776 | # | |
1777 | # Use this if you are connecting to a provider that is not OpenID Connect | |
1778 | # compliant. | |
1779 | # Avoid this in production. | |
1780 | # | |
1781 | #skip_verification: true | |
1782 | ||
1783 | # Whether to fetch the user profile from the userinfo endpoint. Valid | |
1784 | # values are: "auto" or "userinfo_endpoint". | |
1785 | # | |
1786 | # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included | |
1787 | # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. | |
1788 | # | |
1789 | #user_profile_method: "userinfo_endpoint" | |
1790 | ||
1791 | # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead | |
1792 | # of failing. This could be used if switching from password logins to OIDC. Defaults to false. | |
1793 | # | |
1794 | #allow_existing_users: true | |
1795 | ||
1796 | # An external module can be provided here as a custom solution to mapping | |
1797 | # attributes returned from a OIDC provider onto a matrix user. | |
1798 | # | |
1799 | user_mapping_provider: | |
1800 | # The custom module's class. Uncomment to use a custom module. | |
1801 | # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. | |
1802 | # | |
1803 | # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers | |
1804 | # for information on implementing a custom mapping provider. | |
1805 | # | |
1806 | #module: mapping_provider.OidcMappingProvider | |
1807 | ||
1808 | # Custom configuration values for the module. This section will be passed as | |
1809 | # a Python dictionary to the user mapping provider module's `parse_config` | |
1810 | # method. | |
1811 | # | |
1812 | # The examples below are intended for the default provider: they should be | |
1813 | # changed if using a custom provider. | |
1814 | # | |
1815 | config: | |
1816 | # name of the claim containing a unique identifier for the user. | |
1817 | # Defaults to `sub`, which OpenID Connect compliant providers should provide. | |
1818 | # | |
1819 | #subject_claim: "sub" | |
1820 | ||
1821 | # Jinja2 template for the localpart of the MXID. | |
1822 | # | |
1823 | # When rendering, this template is given the following variables: | |
1824 | # * user: The claims returned by the UserInfo Endpoint and/or in the ID | |
1825 | # Token | |
1826 | # | |
1827 | # If this is not set, the user will be prompted to choose their | |
1828 | # own username. | |
1829 | # | |
1830 | #localpart_template: "{{ user.preferred_username }}" | |
1831 | ||
1832 | # Jinja2 template for the display name to set on first login. | |
1833 | # | |
1834 | # If unset, no displayname will be set. | |
1835 | # | |
1836 | #display_name_template: "{{ user.given_name }} {{ user.last_name }}" | |
1837 | ||
1838 | # Jinja2 templates for extra attributes to send back to the client during | |
1839 | # login. | |
1840 | # | |
1841 | # Note that these are non-standard and clients will ignore them without modifications. | |
1842 | # | |
1843 | #extra_attributes: | |
1844 | #birthdate: "{{ user.birthdate }}" | |
1845 | ||
1814 | # for information on how to configure these options. | |
1815 | # | |
1816 | # For backwards compatibility, it is also possible to configure a single OIDC | |
1817 | # provider via an 'oidc_config' setting. This is now deprecated and admins are | |
1818 | # advised to migrate to the 'oidc_providers' format. (When doing that migration, | |
1819 | # use 'oidc' for the idp_id to ensure that existing users continue to be | |
1820 | # recognised.) | |
1821 | # | |
1822 | oidc_providers: | |
1823 | # Generic example | |
1824 | # | |
1825 | #- idp_id: my_idp | |
1826 | # idp_name: "My OpenID provider" | |
1827 | # idp_icon: "mxc://example.com/mediaid" | |
1828 | # discover: false | |
1829 | # issuer: "https://accounts.example.com/" | |
1830 | # client_id: "provided-by-your-issuer" | |
1831 | # client_secret: "provided-by-your-issuer" | |
1832 | # client_auth_method: client_secret_post | |
1833 | # scopes: ["openid", "profile"] | |
1834 | # authorization_endpoint: "https://accounts.example.com/oauth2/auth" | |
1835 | # token_endpoint: "https://accounts.example.com/oauth2/token" | |
1836 | # userinfo_endpoint: "https://accounts.example.com/userinfo" | |
1837 | # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" | |
1838 | # skip_verification: true | |
1839 | ||
1840 | # For use with Keycloak | |
1841 | # | |
1842 | #- idp_id: keycloak | |
1843 | # idp_name: Keycloak | |
1844 | # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" | |
1845 | # client_id: "synapse" | |
1846 | # client_secret: "copy secret generated in Keycloak UI" | |
1847 | # scopes: ["openid", "profile"] | |
1848 | ||
1849 | # For use with Github | |
1850 | # | |
1851 | #- idp_id: github | |
1852 | # idp_name: Github | |
1853 | # discover: false | |
1854 | # issuer: "https://github.com/" | |
1855 | # client_id: "your-client-id" # TO BE FILLED | |
1856 | # client_secret: "your-client-secret" # TO BE FILLED | |
1857 | # authorization_endpoint: "https://github.com/login/oauth/authorize" | |
1858 | # token_endpoint: "https://github.com/login/oauth/access_token" | |
1859 | # userinfo_endpoint: "https://api.github.com/user" | |
1860 | # scopes: ["read:user"] | |
1861 | # user_mapping_provider: | |
1862 | # config: | |
1863 | # subject_claim: "id" | |
1864 | # localpart_template: "{{ user.login }}" | |
1865 | # display_name_template: "{{ user.name }}" | |
1846 | 1866 | |
1847 | 1867 | |
1848 | 1868 | # Enable Central Authentication Service (CAS) for registration and login. |
1892 | 1912 | # phishing attacks from evil.site. To avoid this, include a slash after the |
1893 | 1913 | # hostname: "https://my.client/". |
1894 | 1914 | # |
1895 | # If public_baseurl is set, then the login fallback page (used by clients | |
1896 | # that don't natively support the required login flows) is whitelisted in | |
1897 | # addition to any URLs in this list. | |
1915 | # The login fallback page (used by clients that don't natively support the | |
1916 | # required login flows) is automatically whitelisted in addition to any URLs | |
1917 | # in this list. | |
1898 | 1918 | # |
1899 | 1919 | # By default, this list is empty. |
1900 | 1920 | # |
1907 | 1927 | # directory, default templates from within the Synapse package will be used. |
1908 | 1928 | # |
1909 | 1929 | # Synapse will look for the following templates in this directory: |
1930 | # | |
1931 | # * HTML page to prompt the user to choose an Identity Provider during | |
1932 | # login: 'sso_login_idp_picker.html'. | |
1933 | # | |
1934 | # This is only used if multiple SSO Identity Providers are configured. | |
1935 | # | |
1936 | # When rendering, this template is given the following variables: | |
1937 | # * redirect_url: the URL that the user will be redirected to after | |
1938 | # login. Needs manual escaping (see | |
1939 | # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). | |
1940 | # | |
1941 | # * server_name: the homeserver's name. | |
1942 | # | |
1943 | # * providers: a list of available Identity Providers. Each element is | |
1944 | # an object with the following attributes: | |
1945 | # * idp_id: unique identifier for the IdP | |
1946 | # * idp_name: user-facing name for the IdP | |
1947 | # | |
1948 | # The rendered HTML page should contain a form which submits its results | |
1949 | # back as a GET request, with the following query parameters: | |
1950 | # | |
1951 | # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed | |
1952 | # to the template) | |
1953 | # | |
1954 | # * idp: the 'idp_id' of the chosen IDP. | |
1910 | 1955 | # |
1911 | 1956 | # * HTML page for a confirmation step before redirecting back to the client |
1912 | 1957 | # with the login token: 'sso_redirect_confirm.html'. |
1942 | 1987 | # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). |
1943 | 1988 | # |
1944 | 1989 | # This template has no additional variables. |
1990 | # | |
1991 | # * HTML page shown after a user-interactive authentication session which | |
1992 | # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. | |
1993 | # | |
1994 | # When rendering, this template is given the following variables: | |
1995 | # * server_name: the homeserver's name. | |
1996 | # * user_id_to_verify: the MXID of the user that we are trying to | |
1997 | # validate. | |
1945 | 1998 | # |
1946 | 1999 | # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) |
1947 | 2000 | # attempts to login: 'sso_account_deactivated.html'. |
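The `sso_login_idp_picker.html` contract documented above (template variables in; a GET form carrying `redirectUrl` and `idp` out) can be illustrated with a minimal Jinja2 sketch. The markup is illustrative, not Synapse's shipped template; it submits to the `/_synapse/client/pick_idp` resource added elsewhere in this release, and Jinja2's autoescaping stands in for the manual escaping the comments call for:

```python
import jinja2

# Minimal, illustrative stand-in for sso_login_idp_picker.html: one submit
# button per provider, sending the provider's idp_id as the 'idp' parameter.
PICKER = """\
<form method="get" action="/_synapse/client/pick_idp">
  <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
  {% for p in providers %}
  <button type="submit" name="idp" value="{{ p.idp_id }}">{{ p.idp_name }}</button>
  {% endfor %}
</form>"""

html = jinja2.Environment(autoescape=True).from_string(PICKER).render(
    redirect_url="https://my.client/",
    server_name="example.com",
    providers=[{"idp_id": "oidc-github", "idp_name": "Github"}],
)
```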
30 | 30 | 1. Adjust synapse configuration files as above. |
31 | 31 | 1. Copy the `*.service` and `*.target` files in [system](system) to |
32 | 32 | `/etc/systemd/system`. |
33 | 1. Run `systemctl deamon-reload` to tell systemd to load the new unit files. | |
33 | 1. Run `systemctl daemon-reload` to tell systemd to load the new unit files. | |
34 | 34 | 1. Run `systemctl enable matrix-synapse.service`. This will configure the |
35 | 35 | synapse master process to be started as part of the `matrix-synapse.target` |
36 | 36 | target. |
14 | 14 | workers only work with PostgreSQL-based Synapse deployments. SQLite should only |
15 | 15 | be used for demo purposes and any admin considering workers should already be |
16 | 16 | running PostgreSQL. |
17 | ||
18 | See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability | |
19 | for a higher level overview. | |
17 | 20 | |
18 | 21 | ## Main process/worker communication |
19 | 22 | |
55 | 58 | virtualenv, these can be installed with: |
56 | 59 | |
57 | 60 | ```sh |
58 | pip install matrix-synapse[redis] | |
61 | pip install "matrix-synapse[redis]" | |
59 | 62 | ``` |
60 | 63 | |
61 | 64 | Note that these dependencies are included when synapse is installed with `pip |
213 | 216 | ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$ |
214 | 217 | ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$ |
215 | 218 | ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$ |
219 | ^/_matrix/client/(api/v1|r0|unstable)/devices$ | |
216 | 220 | ^/_matrix/client/(api/v1|r0|unstable)/keys/query$ |
217 | 221 | ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$ |
218 | 222 | ^/_matrix/client/versions$ |
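These patterns are plain regular expressions matched against the request path when routing traffic to a worker; a quick, illustrative check that the newly added `devices` pattern matches as intended:

```python
import re

# Illustrative only: the anchored pattern matches the endpoint itself but
# not sub-paths beneath it.
PATTERN = r"^/_matrix/client/(api/v1|r0|unstable)/devices$"
assert re.match(PATTERN, "/_matrix/client/r0/devices")
assert not re.match(PATTERN, "/_matrix/client/r0/devices/abc")
```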
99 | 99 | synapse/util/async_helpers.py, |
100 | 100 | synapse/util/caches, |
101 | 101 | synapse/util/metrics.py, |
102 | synapse/util/stringutils.py, | |
102 | 103 | tests/replication, |
103 | 104 | tests/test_utils, |
104 | 105 | tests/handlers/test_password_providers.py, |
106 | tests/rest/client/v1/test_login.py, | |
105 | 107 | tests/rest/client/v2_alpha/test_auth.py, |
106 | 108 | tests/util/test_stream_change_cache.py |
107 | 109 |
69 | 69 | |
70 | 70 | BOOLEAN_COLUMNS = { |
71 | 71 | "events": ["processed", "outlier", "contains_url"], |
72 | "rooms": ["is_public"], | |
72 | "rooms": ["is_public", "has_auth_chain_index"], | |
73 | 73 | "event_edges": ["is_state"], |
74 | 74 | "presence_list": ["accepted"], |
75 | 75 | "presence_stream": ["currently_active"], |
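The new `has_auth_chain_index` entry feeds the same conversion step as the other listed columns; a hedged sketch (not the porter's exact code) of what that step does:

```python
# SQLite stores booleans as 0/1 integers, whereas PostgreSQL BOOLEAN columns
# need real booleans, so the listed columns are coerced row by row.
def convert_row(table, headers, row):
    bool_cols = BOOLEAN_COLUMNS.get(table, ())
    return [
        bool(val) if col in bool_cols and val is not None else val
        for col, val in zip(headers, row)
    ]
```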
628 | 628 | await self._setup_state_group_id_seq() |
629 | 629 | await self._setup_user_id_seq() |
630 | 630 | await self._setup_events_stream_seqs() |
631 | await self._setup_device_inbox_seq() | |
631 | 632 | |
632 | 633 | # Step 3. Get tables. |
633 | 634 | self.progress.set_state("Fetching tables") |
909 | 910 | return await self.postgres_store.db_pool.runInteraction( |
910 | 911 | "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, |
911 | 912 | ) |
913 | ||
914 | async def _setup_device_inbox_seq(self): | |
915 | """Set the device inbox sequence to the correct value. | |
916 | """ | |
917 | curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol( | |
918 | table="device_inbox", | |
919 | keyvalues={}, | |
920 | retcol="COALESCE(MAX(stream_id), 1)", | |
921 | allow_none=True, | |
922 | ) | |
923 | ||
924 | curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol( | |
925 | table="device_federation_outbox", | |
926 | keyvalues={}, | |
927 | retcol="COALESCE(MAX(stream_id), 1)", | |
928 | allow_none=True, | |
929 | ) | |
930 | ||
931 | next_id = max(curr_local_id, curr_federation_id) + 1 | |
932 | ||
933 | def r(txn): | |
934 | txn.execute( | |
935 | "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,) | |
936 | ) | |
937 | ||
938 | return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r) | |
912 | 939 | |
913 | 940 | |
914 | 941 | ############################################## |
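A worked example of the sequence bump performed by `_setup_device_inbox_seq` above, with hypothetical values:

```python
# If device_inbox's max stream_id is 41 and device_federation_outbox's is 57,
# the porter issues: ALTER SEQUENCE device_inbox_sequence RESTART WITH 58,
# so the first id PostgreSQL hands out after the port is 58.
next_id = max(41, 57) + 1  # == 58
```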
14 | 14 | |
15 | 15 | # Stub for frozendict. |
16 | 16 | |
17 | from typing import ( | |
18 | Any, | |
19 | Hashable, | |
20 | Iterable, | |
21 | Iterator, | |
22 | Mapping, | |
23 | overload, | |
24 | Tuple, | |
25 | TypeVar, | |
26 | ) | |
17 | from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload | |
27 | 18 | |
28 | 19 | _KT = TypeVar("_KT", bound=Hashable) # Key type. |
29 | 20 | _VT = TypeVar("_VT") # Value type. |
6 | 6 | Callable, |
7 | 7 | Dict, |
8 | 8 | Hashable, |
9 | ItemsView, | |
10 | Iterable, | |
9 | 11 | Iterator, |
10 | Iterable, | |
11 | ItemsView, | |
12 | 12 | KeysView, |
13 | 13 | List, |
14 | 14 | Mapping, |
15 | 15 | Optional, |
16 | 16 | Sequence, |
17 | Tuple, | |
17 | 18 | Type, |
18 | 19 | TypeVar, |
19 | Tuple, | |
20 | 20 | Union, |
21 | 21 | ValuesView, |
22 | 22 | overload, |
15 | 15 | """Contains *incomplete* type hints for txredisapi. |
16 | 16 | """ |
17 | 17 | |
18 | from typing import List, Optional, Union, Type | |
18 | from typing import List, Optional, Type, Union | |
19 | 19 | |
20 | 20 | class RedisProtocol: |
21 | 21 | def publish(self, channel: str, message: bytes): ... |
47 | 47 | except ImportError: |
48 | 48 | pass |
49 | 49 | |
50 | __version__ = "1.25.0" | |
50 | __version__ = "1.26.0" | |
51 | 51 | |
52 | 52 | if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): |
53 | 53 | # We import here so that we don't have to install a bunch of deps when |
32 | 32 | from synapse.api.room_versions import KNOWN_ROOM_VERSIONS |
33 | 33 | from synapse.appservice import ApplicationService |
34 | 34 | from synapse.events import EventBase |
35 | from synapse.http import get_request_user_agent | |
35 | 36 | from synapse.http.site import SynapseRequest |
36 | 37 | from synapse.logging import opentracing as opentracing |
37 | 38 | from synapse.storage.databases.main.registration import TokenLookupResult |
185 | 186 | AuthError if access is denied for the user in the access token |
186 | 187 | """ |
187 | 188 | try: |
188 | ip_addr = self.hs.get_ip_from_request(request) | |
189 | user_agent = request.get_user_agent("") | |
189 | ip_addr = request.getClientIP() | |
190 | user_agent = get_request_user_agent(request) | |
190 | 191 | |
191 | 192 | access_token = self.get_access_token_from_request(request) |
192 | 193 | |
274 | 275 | return None, None |
275 | 276 | |
276 | 277 | if app_service.ip_range_whitelist: |
277 | ip_address = IPAddress(self.hs.get_ip_from_request(request)) | |
278 | ip_address = IPAddress(request.getClientIP()) | |
278 | 279 | if ip_address not in app_service.ip_range_whitelist: |
279 | 280 | return None, None |
280 | 281 |
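The new `get_request_user_agent` helper centralises User-Agent extraction; a hedged sketch of what such a helper plausibly does (the real implementation in `synapse.http` may differ):

```python
def get_request_user_agent(request, default: str = "") -> str:
    # Read the last User-Agent header from the twisted request, decoding
    # defensively; fall back to the given default when the header is absent.
    headers = request.requestHeaders.getRawHeaders(b"User-Agent")
    if not headers:
        return default
    return headers[-1].decode("ascii", errors="replace")
```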
50 | 50 | class RoomVersion: |
51 | 51 | """An object which describes the unique attributes of a room version.""" |
52 | 52 | |
53 | identifier = attr.ib() # str; the identifier for this version | |
54 | disposition = attr.ib() # str; one of the RoomDispositions | |
55 | event_format = attr.ib() # int; one of the EventFormatVersions | |
56 | state_res = attr.ib() # int; one of the StateResolutionVersions | |
57 | enforce_key_validity = attr.ib() # bool | |
53 | identifier = attr.ib(type=str) # the identifier for this version | |
54 | disposition = attr.ib(type=str) # one of the RoomDispositions | |
55 | event_format = attr.ib(type=int) # one of the EventFormatVersions | |
56 | state_res = attr.ib(type=int) # one of the StateResolutionVersions | |
57 | enforce_key_validity = attr.ib(type=bool) | |
58 | 58 | |
59 | 59 | # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules |
60 | 60 | special_case_aliases_auth = attr.ib(type=bool) |
63 | 63 | # * Floats |
64 | 64 | # * NaN, Infinity, -Infinity |
65 | 65 | strict_canonicaljson = attr.ib(type=bool) |
66 | # bool: MSC2209: Check 'notifications' key while verifying | |
66 | # MSC2209: Check 'notifications' key while verifying | |
67 | 67 | # m.room.power_levels auth rules. |
68 | 68 | limit_notifications_power_levels = attr.ib(type=bool) |
69 | # MSC2174/MSC2176: Apply updated redaction rules algorithm. | |
70 | msc2176_redaction_rules = attr.ib(type=bool) | |
69 | 71 | |
70 | 72 | |
71 | 73 | class RoomVersions: |
78 | 80 | special_case_aliases_auth=True, |
79 | 81 | strict_canonicaljson=False, |
80 | 82 | limit_notifications_power_levels=False, |
83 | msc2176_redaction_rules=False, | |
81 | 84 | ) |
82 | 85 | V2 = RoomVersion( |
83 | 86 | "2", |
88 | 91 | special_case_aliases_auth=True, |
89 | 92 | strict_canonicaljson=False, |
90 | 93 | limit_notifications_power_levels=False, |
94 | msc2176_redaction_rules=False, | |
91 | 95 | ) |
92 | 96 | V3 = RoomVersion( |
93 | 97 | "3", |
98 | 102 | special_case_aliases_auth=True, |
99 | 103 | strict_canonicaljson=False, |
100 | 104 | limit_notifications_power_levels=False, |
105 | msc2176_redaction_rules=False, | |
101 | 106 | ) |
102 | 107 | V4 = RoomVersion( |
103 | 108 | "4", |
108 | 113 | special_case_aliases_auth=True, |
109 | 114 | strict_canonicaljson=False, |
110 | 115 | limit_notifications_power_levels=False, |
116 | msc2176_redaction_rules=False, | |
111 | 117 | ) |
112 | 118 | V5 = RoomVersion( |
113 | 119 | "5", |
118 | 124 | special_case_aliases_auth=True, |
119 | 125 | strict_canonicaljson=False, |
120 | 126 | limit_notifications_power_levels=False, |
127 | msc2176_redaction_rules=False, | |
121 | 128 | ) |
122 | 129 | V6 = RoomVersion( |
123 | 130 | "6", |
128 | 135 | special_case_aliases_auth=False, |
129 | 136 | strict_canonicaljson=True, |
130 | 137 | limit_notifications_power_levels=True, |
138 | msc2176_redaction_rules=False, | |
139 | ) | |
140 | MSC2176 = RoomVersion( | |
141 | "org.matrix.msc2176", | |
142 | RoomDisposition.UNSTABLE, | |
143 | EventFormatVersions.V3, | |
144 | StateResolutionVersions.V2, | |
145 | enforce_key_validity=True, | |
146 | special_case_aliases_auth=False, | |
147 | strict_canonicaljson=True, | |
148 | limit_notifications_power_levels=True, | |
149 | msc2176_redaction_rules=True, | |
131 | 150 | ) |
132 | 151 | |
133 | 152 | |
140 | 159 | RoomVersions.V4, |
141 | 160 | RoomVersions.V5, |
142 | 161 | RoomVersions.V6, |
162 | RoomVersions.MSC2176, | |
143 | 163 | ) |
144 | 164 | } # type: Dict[str, RoomVersion] |
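The new `msc2176_redaction_rules` flag gates redaction behaviour per room version. An illustrative sketch, not Synapse's actual pruning code: under MSC2176, `origin`, `membership` and `prev_state` stop being preserved at the top level of redacted events.

```python
def preserved_top_level_keys(room_version) -> list:
    # Illustrative only: keys kept when redacting an event.
    keys = [
        "event_id", "type", "room_id", "sender", "state_key", "content",
        "hashes", "signatures", "depth", "prev_events", "auth_events",
        "origin_server_ts",
    ]
    if not room_version.msc2176_redaction_rules:
        # Pre-MSC2176 redaction rules also preserve these.
        keys += ["origin", "membership", "prev_state"]
    return keys
```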
41 | 41 | """ |
42 | 42 | if hs_config.form_secret is None: |
43 | 43 | raise ConfigError("form_secret not set in config") |
44 | if hs_config.public_baseurl is None: | |
45 | raise ConfigError("public_baseurl not set in config") | |
46 | 44 | |
47 | 45 | self._hmac_secret = hs_config.form_secret.encode("utf-8") |
48 | 46 | self._public_baseurl = hs_config.public_baseurl |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2017 New Vector Ltd |
2 | # Copyright 2019-2021 The Matrix.org Foundation C.I.C | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
18 | 19 | import socket |
19 | 20 | import sys |
20 | 21 | import traceback |
21 | from typing import Iterable | |
22 | from typing import Awaitable, Callable, Iterable | |
22 | 23 | |
23 | 24 | from typing_extensions import NoReturn |
24 | 25 | |
140 | 141 | sys.stderr.write(" %s\n" % (line.rstrip(),)) |
141 | 142 | sys.stderr.write("*" * line_length + "\n") |
142 | 143 | sys.exit(1) |
144 | ||
145 | ||
146 | def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: | |
147 | """Register a callback with the reactor, to be called once it is running | |
148 | ||
149 | This can be used to initialise parts of the system which require an asynchronous | |
150 | setup. | |
151 | ||
152 | Any exception raised by the callback will be printed and logged, and the process | |
153 | will exit. | |
154 | """ | |
155 | ||
156 | async def wrapper(): | |
157 | try: | |
158 | await cb(*args, **kwargs) | |
159 | except Exception: | |
160 | # previously, we used Failure().printTraceback() here, in the hope that | |
161 | # would give better tracebacks than traceback.print_exc(). However, that | |
162 | # doesn't handle chained exceptions (with a __cause__ or __context__) well, | |
163 | # and I *think* the need for Failure() is reduced now that we mostly use | |
164 | # async/await. | |
165 | ||
166 | # Write the exception to both the logs *and* the unredirected stderr, | |
167 | # because people tend to get confused if it only goes to one or the other. | |
168 | # | |
169 | # One problem with this is that if people are using a logging config that | |
170 | # logs to the console (as is common eg under docker), they will get two | |
171 | # copies of the exception. We could maybe try to detect that, but it's | |
172 | # probably a cost we can bear. | |
173 | logger.fatal("Error during startup", exc_info=True) | |
174 | print("Error during startup:", file=sys.__stderr__) | |
175 | traceback.print_exc(file=sys.__stderr__) | |
176 | ||
177 | # it's no use calling sys.exit here, since that just raises a SystemExit | |
178 | # exception which is then caught by the reactor, and everything carries | |
179 | # on as normal. | |
180 | os._exit(1) | |
181 | ||
182 | reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) | |
143 | 183 | |
144 | 184 | |
145 | 185 | def listen_metrics(bind_addresses, port): |
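A hedged usage sketch for the `register_start` helper above; `warm_caches` and the call it awaits are hypothetical, shown only to illustrate the calling convention:

```python
# The coroutine runs once the reactor is up; if it raises, the error is
# logged, printed to stderr, and the process exits via os._exit(1).
async def warm_caches(hs):
    await hs.get_datastore().some_async_warmup()  # hypothetical call

register_start(warm_caches, hs)
```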
226 | 266 | logger.info("Context factories updated.") |
227 | 267 | |
228 | 268 | |
229 | def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): | |
269 | async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): | |
230 | 270 | """ |
231 | 271 | Start a Synapse server or worker. |
232 | 272 | |
240 | 280 | hs: homeserver instance |
241 | 281 | listeners: Listener configuration ('listeners' in homeserver.yaml) |
242 | 282 | """ |
243 | try: | |
244 | # Set up the SIGHUP machinery. | |
245 | if hasattr(signal, "SIGHUP"): | |
246 | ||
247 | reactor = hs.get_reactor() | |
248 | ||
249 | @wrap_as_background_process("sighup") | |
250 | def handle_sighup(*args, **kwargs): | |
251 | # Tell systemd our state, if we're using it. This will silently fail if | |
252 | # we're not using systemd. | |
253 | sdnotify(b"RELOADING=1") | |
254 | ||
255 | for i, args, kwargs in _sighup_callbacks: | |
256 | i(*args, **kwargs) | |
257 | ||
258 | sdnotify(b"READY=1") | |
259 | ||
260 | # We defer running the sighup handlers until next reactor tick. This | |
261 | # is so that we're in a sane state, e.g. flushing the logs may fail | |
262 | # if the sighup happens in the middle of writing a log entry. | |
263 | def run_sighup(*args, **kwargs): | |
264 | # `callFromThread` should be "signal safe" as well as thread | |
265 | # safe. | |
266 | reactor.callFromThread(handle_sighup, *args, **kwargs) | |
267 | ||
268 | signal.signal(signal.SIGHUP, run_sighup) | |
269 | ||
270 | register_sighup(refresh_certificate, hs) | |
271 | ||
272 | # Load the certificate from disk. | |
273 | refresh_certificate(hs) | |
274 | ||
275 | # Start the tracer | |
276 | synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa | |
277 | hs | |
278 | ) | |
279 | ||
280 | # It is now safe to start your Synapse. | |
281 | hs.start_listening(listeners) | |
282 | hs.get_datastore().db_pool.start_profiling() | |
283 | hs.get_pusherpool().start() | |
284 | ||
285 | # Log when we start the shut down process. | |
286 | hs.get_reactor().addSystemEventTrigger( | |
287 | "before", "shutdown", logger.info, "Shutting down..." | |
288 | ) | |
289 | ||
290 | setup_sentry(hs) | |
291 | setup_sdnotify(hs) | |
292 | ||
293 | # If background tasks are running on the main process, start collecting the | |
294 | # phone home stats. | |
295 | if hs.config.run_background_tasks: | |
296 | start_phone_stats_home(hs) | |
297 | ||
298 | # We now freeze all allocated objects in the hopes that (almost) | |
299 | # everything currently allocated are things that will be used for the | |
300 | # rest of time. Doing so means less work each GC (hopefully). | |
301 | # | |
302 | # This only works on Python 3.7 | |
303 | if sys.version_info >= (3, 7): | |
304 | gc.collect() | |
305 | gc.freeze() | |
306 | except Exception: | |
307 | traceback.print_exc(file=sys.stderr) | |
283 | # Set up the SIGHUP machinery. | |
284 | if hasattr(signal, "SIGHUP"): | |
308 | 285 | reactor = hs.get_reactor() |
309 | if reactor.running: | |
310 | reactor.stop() | |
311 | sys.exit(1) | |
286 | ||
287 | @wrap_as_background_process("sighup") | |
288 | def handle_sighup(*args, **kwargs): | |
289 | # Tell systemd our state, if we're using it. This will silently fail if | |
290 | # we're not using systemd. | |
291 | sdnotify(b"RELOADING=1") | |
292 | ||
293 | for i, args, kwargs in _sighup_callbacks: | |
294 | i(*args, **kwargs) | |
295 | ||
296 | sdnotify(b"READY=1") | |
297 | ||
298 | # We defer running the sighup handlers until next reactor tick. This | |
299 | # is so that we're in a sane state, e.g. flushing the logs may fail | |
300 | # if the sighup happens in the middle of writing a log entry. | |
301 | def run_sighup(*args, **kwargs): | |
302 | # `callFromThread` should be "signal safe" as well as thread | |
303 | # safe. | |
304 | reactor.callFromThread(handle_sighup, *args, **kwargs) | |
305 | ||
306 | signal.signal(signal.SIGHUP, run_sighup) | |
307 | ||
308 | register_sighup(refresh_certificate, hs) | |
309 | ||
310 | # Load the certificate from disk. | |
311 | refresh_certificate(hs) | |
312 | ||
313 | # Start the tracer | |
314 | synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa | |
315 | hs | |
316 | ) | |
317 | ||
318 | # It is now safe to start your Synapse. | |
319 | hs.start_listening(listeners) | |
320 | hs.get_datastore().db_pool.start_profiling() | |
321 | hs.get_pusherpool().start() | |
322 | ||
323 | # Log when we start the shut down process. | |
324 | hs.get_reactor().addSystemEventTrigger( | |
325 | "before", "shutdown", logger.info, "Shutting down..." | |
326 | ) | |
327 | ||
328 | setup_sentry(hs) | |
329 | setup_sdnotify(hs) | |
330 | ||
331 | # If background tasks are running on the main process, start collecting the | |
332 | # phone home stats. | |
333 | if hs.config.run_background_tasks: | |
334 | start_phone_stats_home(hs) | |
335 | ||
336 | # We now freeze all allocated objects in the hopes that (almost) | |
337 | # everything currently allocated are things that will be used for the | |
338 | # rest of time. Doing so means less work each GC (hopefully). | |
339 | # | |
340 | # This only works on Python 3.7 | |
341 | if sys.version_info >= (3, 7): | |
342 | gc.collect() | |
343 | gc.freeze() | |
312 | 344 | |
313 | 345 | |
314 | 346 | def setup_sentry(hs): |
20 | 20 | |
21 | 21 | from typing_extensions import ContextManager |
22 | 22 | |
23 | from twisted.internet import address, reactor | |
23 | from twisted.internet import address | |
24 | 24 | |
25 | 25 | import synapse |
26 | 26 | import synapse.events |
33 | 33 | SERVER_KEY_V2_PREFIX, |
34 | 34 | ) |
35 | 35 | from synapse.app import _base |
36 | from synapse.app._base import register_start | |
36 | 37 | from synapse.config._base import ConfigError |
37 | 38 | from synapse.config.homeserver import HomeServerConfig |
38 | 39 | from synapse.config.logger import setup_logging |
98 | 99 | ) |
99 | 100 | from synapse.rest.client.v1.push_rule import PushRuleRestServlet |
100 | 101 | from synapse.rest.client.v1.voip import VoipRestServlet |
101 | from synapse.rest.client.v2_alpha import groups, sync, user_directory | |
102 | from synapse.rest.client.v2_alpha import ( | |
103 | account_data, | |
104 | groups, | |
105 | read_marker, | |
106 | receipts, | |
107 | room_keys, | |
108 | sync, | |
109 | tags, | |
110 | user_directory, | |
111 | ) | |
102 | 112 | from synapse.rest.client.v2_alpha._base import client_patterns |
103 | 113 | from synapse.rest.client.v2_alpha.account import ThreepidRestServlet |
104 | 114 | from synapse.rest.client.v2_alpha.account_data import ( |
105 | 115 | AccountDataServlet, |
106 | 116 | RoomAccountDataServlet, |
107 | 117 | ) |
108 | from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet | |
118 | from synapse.rest.client.v2_alpha.devices import DevicesRestServlet | |
119 | from synapse.rest.client.v2_alpha.keys import ( | |
120 | KeyChangesServlet, | |
121 | KeyQueryServlet, | |
122 | OneTimeKeyServlet, | |
123 | ) | |
109 | 124 | from synapse.rest.client.v2_alpha.register import RegisterRestServlet |
125 | from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet | |
110 | 126 | from synapse.rest.client.versions import VersionsRestServlet |
111 | 127 | from synapse.rest.health import HealthResource |
112 | 128 | from synapse.rest.key.v2 import KeyApiV2Resource |
113 | 129 | from synapse.server import HomeServer, cache_in_self |
114 | 130 | from synapse.storage.databases.main.censor_events import CensorEventsStore |
115 | 131 | from synapse.storage.databases.main.client_ips import ClientIpWorkerStore |
132 | from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore | |
116 | 133 | from synapse.storage.databases.main.media_repository import MediaRepositoryStore |
117 | 134 | from synapse.storage.databases.main.metrics import ServerMetricsStore |
118 | 135 | from synapse.storage.databases.main.monthly_active_users import ( |
444 | 461 | UserDirectoryStore, |
445 | 462 | StatsStore, |
446 | 463 | UIAuthWorkerStore, |
464 | EndToEndRoomKeyStore, | |
447 | 465 | SlavedDeviceInboxStore, |
448 | 466 | SlavedDeviceStore, |
449 | 467 | SlavedReceiptsStore, |
500 | 518 | RegisterRestServlet(self).register(resource) |
501 | 519 | LoginRestServlet(self).register(resource) |
502 | 520 | ThreepidRestServlet(self).register(resource) |
521 | DevicesRestServlet(self).register(resource) | |
503 | 522 | KeyQueryServlet(self).register(resource) |
523 | OneTimeKeyServlet(self).register(resource) | |
504 | 524 | KeyChangesServlet(self).register(resource) |
505 | 525 | VoipRestServlet(self).register(resource) |
506 | 526 | PushRuleRestServlet(self).register(resource) |
518 | 538 | room.register_servlets(self, resource, True) |
519 | 539 | room.register_deprecated_servlets(self, resource) |
520 | 540 | InitialSyncRestServlet(self).register(resource) |
541 | room_keys.register_servlets(self, resource) | |
542 | tags.register_servlets(self, resource) | |
543 | account_data.register_servlets(self, resource) | |
544 | receipts.register_servlets(self, resource) | |
545 | read_marker.register_servlets(self, resource) | |
546 | ||
547 | SendToDeviceRestServlet(self).register(resource) | |
521 | 548 | |
522 | 549 | user_directory.register_servlets(self, resource) |
523 | 550 | |
956 | 983 | # streams. Will no-op if no streams can be written to by this worker. |
957 | 984 | hs.get_replication_streamer() |
958 | 985 | |
959 | reactor.addSystemEventTrigger( | |
960 | "before", "startup", _base.start, hs, config.worker_listeners | |
961 | ) | |
986 | register_start(_base.start, hs, config.worker_listeners) | |
962 | 987 | |
963 | 988 | _base.start_worker_reactor("synapse-generic-worker", config) |
964 | 989 |
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 | |
17 | import gc | |
18 | 17 | import logging |
19 | 18 | import os |
20 | 19 | import sys |
21 | 20 | from typing import Iterable, Iterator |
22 | 21 | |
23 | from twisted.application import service | |
24 | from twisted.internet import defer, reactor | |
25 | from twisted.python.failure import Failure | |
22 | from twisted.internet import reactor | |
26 | 23 | from twisted.web.resource import EncodingResourceWrapper, IResource |
27 | 24 | from twisted.web.server import GzipEncoderFactory |
28 | 25 | from twisted.web.static import File |
39 | 36 | WEB_CLIENT_PREFIX, |
40 | 37 | ) |
41 | 38 | from synapse.app import _base |
42 | from synapse.app._base import listen_ssl, listen_tcp, quit_with_error | |
39 | from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start | |
43 | 40 | from synapse.config._base import ConfigError |
44 | 41 | from synapse.config.emailconfig import ThreepidBehaviour |
45 | 42 | from synapse.config.homeserver import HomeServerConfig |
62 | 59 | from synapse.rest.admin import AdminRestResource |
63 | 60 | from synapse.rest.health import HealthResource |
64 | 61 | from synapse.rest.key.v2 import KeyApiV2Resource |
62 | from synapse.rest.synapse.client.pick_idp import PickIdpResource | |
65 | 63 | from synapse.rest.synapse.client.pick_username import pick_username_resource |
66 | 64 | from synapse.rest.well_known import WellKnownResource |
67 | 65 | from synapse.server import HomeServer |
71 | 69 | from synapse.util.httpresourcetree import create_resource_tree |
72 | 70 | from synapse.util.manhole import manhole |
73 | 71 | from synapse.util.module_loader import load_module |
74 | from synapse.util.rlimit import change_resource_limit | |
75 | 72 | from synapse.util.versionstring import get_version_string |
76 | 73 | |
77 | 74 | logger = logging.getLogger("synapse.app.homeserver") |
193 | 190 | "/.well-known/matrix/client": WellKnownResource(self), |
194 | 191 | "/_synapse/admin": AdminRestResource(self), |
195 | 192 | "/_synapse/client/pick_username": pick_username_resource(self), |
193 | "/_synapse/client/pick_idp": PickIdpResource(self), | |
196 | 194 | } |
197 | 195 | ) |
198 | 196 | |
414 | 412 | _base.refresh_certificate(hs) |
415 | 413 | |
416 | 414 | async def start(): |
417 | try: | |
418 | # Run the ACME provisioning code, if it's enabled. | |
419 | if hs.config.acme_enabled: | |
420 | acme = hs.get_acme_handler() | |
421 | # Start up the webservices which we will respond to ACME | |
422 | # challenges with, and then provision. | |
423 | await acme.start_listening() | |
424 | await do_acme() | |
425 | ||
426 | # Check if it needs to be reprovisioned every day. | |
427 | hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) | |
428 | ||
429 | # Load the OIDC provider metadatas, if OIDC is enabled. | |
430 | if hs.config.oidc_enabled: | |
431 | oidc = hs.get_oidc_handler() | |
432 | # Loading the provider metadata also ensures the provider config is valid. | |
433 | await oidc.load_metadata() | |
434 | await oidc.load_jwks() | |
435 | ||
436 | _base.start(hs, config.listeners) | |
437 | ||
438 | hs.get_datastore().db_pool.updates.start_doing_background_updates() | |
439 | except Exception: | |
440 | # Print the exception and bail out. | |
441 | print("Error during startup:", file=sys.stderr) | |
442 | ||
443 | # this gives better tracebacks than traceback.print_exc() | |
444 | Failure().printTraceback(file=sys.stderr) | |
445 | ||
446 | if reactor.running: | |
447 | reactor.stop() | |
448 | sys.exit(1) | |
449 | ||
450 | reactor.callWhenRunning(lambda: defer.ensureDeferred(start())) | |
415 | # Run the ACME provisioning code, if it's enabled. | |
416 | if hs.config.acme_enabled: | |
417 | acme = hs.get_acme_handler() | |
418 | # Start up the webservices which we will respond to ACME | |
419 | # challenges with, and then provision. | |
420 | await acme.start_listening() | |
421 | await do_acme() | |
422 | ||
423 | # Check if it needs to be reprovisioned every day. | |
424 | hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) | |
425 | ||
426 | # Load the OIDC provider metadatas, if OIDC is enabled. | |
427 | if hs.config.oidc_enabled: | |
428 | oidc = hs.get_oidc_handler() | |
429 | # Loading the provider metadata also ensures the provider config is valid. | |
430 | await oidc.load_metadata() | |
431 | ||
432 | await _base.start(hs, config.listeners) | |
433 | ||
434 | hs.get_datastore().db_pool.updates.start_doing_background_updates() | |
435 | ||
436 | register_start(start) | |
451 | 437 | |
452 | 438 | return hs |
453 | 439 | |
482 | 468 | indent += 1 |
483 | 469 | yield ":\n%s%s" % (" " * indent, str(e)) |
484 | 470 | e = e.__cause__ |
485 | ||
486 | ||
487 | class SynapseService(service.Service): | |
488 | """ | |
489 | A twisted Service class that will start synapse. Used to run synapse | |
490 | via twistd and a .tac. | |
491 | """ | |
492 | ||
493 | def __init__(self, config): | |
494 | self.config = config | |
495 | ||
496 | def startService(self): | |
497 | hs = setup(self.config) | |
498 | change_resource_limit(hs.config.soft_file_limit) | |
499 | if hs.config.gc_thresholds: | |
500 | gc.set_threshold(*hs.config.gc_thresholds) | |
501 | ||
502 | def stopService(self): | |
503 | return self._port.stopListening() | |
504 | 471 | |
505 | 472 | |
506 | 473 | def run(hs): |
251 | 251 | env = jinja2.Environment(loader=loader, autoescape=autoescape) |
252 | 252 | |
253 | 253 | # Update the environment with our custom filters |
254 | env.filters.update({"format_ts": _format_ts_filter}) | |
255 | if self.public_baseurl: | |
256 | env.filters.update( | |
257 | {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} | |
258 | ) | |
254 | env.filters.update( | |
255 | { | |
256 | "format_ts": _format_ts_filter, | |
257 | "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), | |
258 | } | |
259 | ) | |
259 | 260 | |
260 | 261 | for filename in filenames: |
261 | 262 | # Load the template |
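Both filters are now registered unconditionally, which is possible because `public_baseurl` now has a default. An illustrative use inside a template, with argument shapes assumed from the filter names rather than taken from this diff:

```python
rendered = env.from_string(
    "sent {{ ts | format_ts('%Y-%m-%d') }}, avatar {{ url | mxc_to_http(32, 32) }}"
).render(ts=1611742800000, url="mxc://example.com/mediaid")
```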
55 | 55 | """ |
56 | 56 | # copy `config_path` before modifying it. |
57 | 57 | path = list(config_path) |
58 | for p in list(e.path): | |
58 | for p in list(e.absolute_path): | |
59 | 59 | if isinstance(p, int): |
60 | 60 | path.append("<item %i>" % p) |
61 | 61 | else: |
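The switch to `e.absolute_path` matters for errors raised inside composed schemas (such as the `oneOf` wrappers used for the OIDC settings): jsonschema's `path` can be relative to a parent error's instance, while `absolute_path` is always anchored at the top-level config, giving error locations like `oidc_providers.<item 0>.issuer` rather than a bare `issuer`.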
39 | 39 | self.cas_required_attributes = {} |
40 | 40 | |
41 | 41 | def generate_config_section(self, config_dir_path, server_name, **kwargs): |
42 | return """ | |
42 | return """\ | |
43 | 43 | # Enable Central Authentication Service (CAS) for registration and login. |
44 | 44 | # |
45 | 45 | cas_config: |
164 | 164 | missing = [] |
165 | 165 | if not self.email_notif_from: |
166 | 166 | missing.append("email.notif_from") |
167 | ||
168 | # public_baseurl is required to build password reset and validation links that | |
169 | # will be emailed to users | |
170 | if config.get("public_baseurl") is None: | |
171 | missing.append("public_baseurl") | |
172 | 167 | |
173 | 168 | if missing: |
174 | 169 | raise ConfigError( |
268 | 263 | if not self.email_notif_from: |
269 | 264 | missing.append("email.notif_from") |
270 | 265 | |
271 | if config.get("public_baseurl") is None: | |
272 | missing.append("public_baseurl") | |
273 | ||
274 | 266 | if missing: |
275 | 267 | raise ConfigError( |
276 | 268 | "email.enable_notifs is True but required keys are missing: %s" |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2020 Quentin Gliech |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
12 | 13 | # See the License for the specific language governing permissions and |
13 | 14 | # limitations under the License. |
14 | 15 | |
16 | import string | |
17 | from typing import Iterable, Optional, Tuple, Type | |
18 | ||
19 | import attr | |
20 | ||
21 | from synapse.config._util import validate_config | |
15 | 22 | from synapse.python_dependencies import DependencyException, check_requirements |
23 | from synapse.types import Collection, JsonDict | |
16 | 24 | from synapse.util.module_loader import load_module |
25 | from synapse.util.stringutils import parse_and_validate_mxc_uri | |
17 | 26 | |
18 | 27 | from ._base import Config, ConfigError |
19 | 28 | |
24 | 33 | section = "oidc" |
25 | 34 | |
26 | 35 | def read_config(self, config, **kwargs): |
27 | self.oidc_enabled = False | |
28 | ||
29 | oidc_config = config.get("oidc_config") | |
30 | ||
31 | if not oidc_config or not oidc_config.get("enabled", False): | |
36 | self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) | |
37 | if not self.oidc_providers: | |
32 | 38 | return |
33 | 39 | |
34 | 40 | try: |
35 | 41 | check_requirements("oidc") |
36 | 42 | except DependencyException as e: |
37 | raise ConfigError(e.message) | |
43 | raise ConfigError(e.message) from e | |
38 | 44 | |
39 | 45 | public_baseurl = self.public_baseurl |
40 | if public_baseurl is None: | |
41 | raise ConfigError("oidc_config requires a public_baseurl to be set") | |
42 | 46 | self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" |
43 | 47 | |
44 | self.oidc_enabled = True | |
45 | self.oidc_discover = oidc_config.get("discover", True) | |
46 | self.oidc_issuer = oidc_config["issuer"] | |
47 | self.oidc_client_id = oidc_config["client_id"] | |
48 | self.oidc_client_secret = oidc_config["client_secret"] | |
49 | self.oidc_client_auth_method = oidc_config.get( | |
50 | "client_auth_method", "client_secret_basic" | |
51 | ) | |
52 | self.oidc_scopes = oidc_config.get("scopes", ["openid"]) | |
53 | self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") | |
54 | self.oidc_token_endpoint = oidc_config.get("token_endpoint") | |
55 | self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") | |
56 | self.oidc_jwks_uri = oidc_config.get("jwks_uri") | |
57 | self.oidc_skip_verification = oidc_config.get("skip_verification", False) | |
58 | self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto") | |
59 | self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) | |
60 | ||
61 | ump_config = oidc_config.get("user_mapping_provider", {}) | |
62 | ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) | |
63 | ump_config.setdefault("config", {}) | |
64 | ||
65 | ( | |
66 | self.oidc_user_mapping_provider_class, | |
67 | self.oidc_user_mapping_provider_config, | |
68 | ) = load_module(ump_config, ("oidc_config", "user_mapping_provider")) | |
69 | ||
70 | # Ensure loaded user mapping module has defined all necessary methods | |
71 | required_methods = [ | |
72 | "get_remote_user_id", | |
73 | "map_user_attributes", | |
74 | ] | |
75 | missing_methods = [ | |
76 | method | |
77 | for method in required_methods | |
78 | if not hasattr(self.oidc_user_mapping_provider_class, method) | |
79 | ] | |
80 | if missing_methods: | |
81 | raise ConfigError( | |
82 | "Class specified by oidc_config." | |
83 | "user_mapping_provider.module is missing required " | |
84 | "methods: %s" % (", ".join(missing_methods),) | |
85 | ) | |
48 | @property | |
49 | def oidc_enabled(self) -> bool: | |
50 | # OIDC is enabled if we have a provider | |
51 | return bool(self.oidc_providers) | |
86 | 52 | |
87 | 53 | def generate_config_section(self, config_dir_path, server_name, **kwargs): |
88 | 54 | return """\ |
89 | # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. | |
55 | # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration | |
56 | # and login. | |
57 | # | |
58 | # Options for each entry include: | |
59 | # | |
60 | # idp_id: a unique identifier for this identity provider. Used internally | |
61 | # by Synapse; should be a single word such as 'github'. | |
62 | # | |
63 | # Note that, if this is changed, users authenticating via that provider | |
64 | # will no longer be recognised as the same user! | |
65 | # | |
66 | # idp_name: A user-facing name for this identity provider, which is used to | |
67 | # offer the user a choice of login mechanisms. | |
68 | # | |
69 | # idp_icon: An optional icon for this identity provider, which is presented | |
70 | # by identity picker pages. If given, must be an MXC URI of the format | |
71 | # mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI | |
72 | # is to upload an image to an (unencrypted) room and then copy the "url" | |
73 | # from the source of the event.) | |
74 | # | |
75 | # discover: set to 'false' to disable the use of the OIDC discovery mechanism | |
76 | # to discover endpoints. Defaults to true. | |
77 | # | |
78 | # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery | |
79 | # is enabled) to discover the provider's endpoints. | |
80 | # | |
81 | # client_id: Required. oauth2 client id to use. | |
82 | # | |
83 | # client_secret: Required. oauth2 client secret to use. | |
84 | # | |
85 | # client_auth_method: auth method to use when exchanging the token. Valid | |
86 | # values are 'client_secret_basic' (default), 'client_secret_post' and | |
87 | # 'none'. | |
88 | # | |
89 | # scopes: list of scopes to request. This should normally include the "openid" | |
90 | # scope. Defaults to ["openid"]. | |
91 | # | |
92 | # authorization_endpoint: the oauth2 authorization endpoint. Required if | |
93 | # provider discovery is disabled. | |
94 | # | |
95 | # token_endpoint: the oauth2 token endpoint. Required if provider discovery is | |
96 | # disabled. | |
97 | # | |
98 | # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is | |
99 | # disabled and the 'openid' scope is not requested. | |
100 | # | |
101 | # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and | |
102 | # the 'openid' scope is used. | |
103 | # | |
104 | # skip_verification: set to 'true' to skip metadata verification. Use this if | |
105 | # you are connecting to a provider that is not OpenID Connect compliant. | |
106 | # Defaults to false. Avoid this in production. | |
107 | # | |
108 | # user_profile_method: Whether to fetch the user profile from the userinfo | |
109 | # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. | |
110 | # | |
111 | # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is | |
112 | # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the | |
113 | # userinfo endpoint. | |
114 | # | |
115 | # allow_existing_users: set to 'true' to allow a user logging in via OIDC to | |
116 | # match a pre-existing account instead of failing. This could be used if | |
117 | # switching from password logins to OIDC. Defaults to false. | |
118 | # | |
119 | # user_mapping_provider: Configuration for how attributes returned from an OIDC | 
120 | # provider are mapped onto a matrix user. This setting has the following | |
121 | # sub-properties: | |
122 | # | |
123 | # module: The class name of a custom mapping module. Default is | |
124 | # {mapping_provider!r}. | |
125 | # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers | |
126 | # for information on implementing a custom mapping provider. | |
127 | # | |
128 | # config: Configuration for the mapping provider module. This section will | |
129 | # be passed as a Python dictionary to the user mapping provider | |
130 | # module's `parse_config` method. | |
131 | # | |
132 | # For the default provider, the following settings are available: | |
133 | # | |
134 | # sub: name of the claim containing a unique identifier for the | |
135 | # user. Defaults to 'sub', which OpenID Connect compliant | |
136 | # providers should provide. | |
137 | # | |
138 | # localpart_template: Jinja2 template for the localpart of the MXID. | |
139 | # If this is not set, the user will be prompted to choose their | |
140 | # own username. | |
141 | # | |
142 | # display_name_template: Jinja2 template for the display name to set | |
143 | # on first login. If unset, no displayname will be set. | |
144 | # | |
145 | # extra_attributes: a map of Jinja2 templates for extra attributes | |
146 | # to send back to the client during login. | |
147 | # Note that these are non-standard and clients will ignore them | |
148 | # without modifications. | |
149 | # | |
150 | # When rendering, the Jinja2 templates are given a 'user' variable, | |
151 | # which is set to the claims returned by the UserInfo Endpoint and/or | |
152 | # in the ID Token. | |
90 | 153 | # |
91 | 154 | # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md |
92 | # for some example configurations. | |
93 | # | |
94 | oidc_config: | |
95 | # Uncomment the following to enable authorization against an OpenID Connect | |
96 | # server. Defaults to false. | |
155 | # for information on how to configure these options. | |
156 | # | |
157 | # For backwards compatibility, it is also possible to configure a single OIDC | |
158 | # provider via an 'oidc_config' setting. This is now deprecated and admins are | |
159 | # advised to migrate to the 'oidc_providers' format. (When doing that migration, | |
160 | # use 'oidc' for the idp_id to ensure that existing users continue to be | |
161 | # recognised.) | |
162 | # | |
163 | oidc_providers: | |
164 | # Generic example | |
97 | 165 | # |
98 | #enabled: true | |
99 | ||
100 | # Uncomment the following to disable use of the OIDC discovery mechanism to | |
101 | # discover endpoints. Defaults to true. | |
166 | #- idp_id: my_idp | |
167 | # idp_name: "My OpenID provider" | |
168 | # idp_icon: "mxc://example.com/mediaid" | |
169 | # discover: false | |
170 | # issuer: "https://accounts.example.com/" | |
171 | # client_id: "provided-by-your-issuer" | |
172 | # client_secret: "provided-by-your-issuer" | |
173 | # client_auth_method: client_secret_post | |
174 | # scopes: ["openid", "profile"] | |
175 | # authorization_endpoint: "https://accounts.example.com/oauth2/auth" | |
176 | # token_endpoint: "https://accounts.example.com/oauth2/token" | |
177 | # userinfo_endpoint: "https://accounts.example.com/userinfo" | |
178 | # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" | |
179 | # skip_verification: true | |
180 | ||
181 | # For use with Keycloak | |
102 | 182 | # |
103 | #discover: false | |
104 | ||
105 | # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to | |
106 | # discover the provider's endpoints. | |
183 | #- idp_id: keycloak | |
184 | # idp_name: Keycloak | |
185 | # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" | |
186 | # client_id: "synapse" | |
187 | # client_secret: "copy secret generated in Keycloak UI" | |
188 | # scopes: ["openid", "profile"] | |
189 | ||
190 | # For use with Github | |
107 | 191 | # |
108 | # Required if 'enabled' is true. | |
109 | # | |
110 | #issuer: "https://accounts.example.com/" | |
111 | ||
112 | # oauth2 client id to use. | |
113 | # | |
114 | # Required if 'enabled' is true. | |
115 | # | |
116 | #client_id: "provided-by-your-issuer" | |
117 | ||
118 | # oauth2 client secret to use. | |
119 | # | |
120 | # Required if 'enabled' is true. | |
121 | # | |
122 | #client_secret: "provided-by-your-issuer" | |
123 | ||
124 | # auth method to use when exchanging the token. | |
125 | # Valid values are 'client_secret_basic' (default), 'client_secret_post' and | |
126 | # 'none'. | |
127 | # | |
128 | #client_auth_method: client_secret_post | |
129 | ||
130 | # list of scopes to request. This should normally include the "openid" scope. | |
131 | # Defaults to ["openid"]. | |
132 | # | |
133 | #scopes: ["openid", "profile"] | |
134 | ||
135 | # the oauth2 authorization endpoint. Required if provider discovery is disabled. | |
136 | # | |
137 | #authorization_endpoint: "https://accounts.example.com/oauth2/auth" | |
138 | ||
139 | # the oauth2 token endpoint. Required if provider discovery is disabled. | |
140 | # | |
141 | #token_endpoint: "https://accounts.example.com/oauth2/token" | |
142 | ||
143 | # the OIDC userinfo endpoint. Required if discovery is disabled and the | |
144 | # "openid" scope is not requested. | |
145 | # | |
146 | #userinfo_endpoint: "https://accounts.example.com/userinfo" | |
147 | ||
148 | # URI where to fetch the JWKS. Required if discovery is disabled and the | |
149 | # "openid" scope is used. | |
150 | # | |
151 | #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" | |
152 | ||
153 | # Uncomment to skip metadata verification. Defaults to false. | |
154 | # | |
155 | # Use this if you are connecting to a provider that is not OpenID Connect | |
156 | # compliant. | |
157 | # Avoid this in production. | |
158 | # | |
159 | #skip_verification: true | |
160 | ||
161 | # Whether to fetch the user profile from the userinfo endpoint. Valid | |
162 | # values are: "auto" or "userinfo_endpoint". | |
163 | # | |
164 | # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included | |
165 | # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. | |
166 | # | |
167 | #user_profile_method: "userinfo_endpoint" | |
168 | ||
169 | # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead | |
170 | # of failing. This could be used if switching from password logins to OIDC. Defaults to false. | |
171 | # | |
172 | #allow_existing_users: true | |
173 | ||
174 | # An external module can be provided here as a custom solution to mapping | |
175 | # attributes returned from a OIDC provider onto a matrix user. | |
176 | # | |
177 | user_mapping_provider: | |
178 | # The custom module's class. Uncomment to use a custom module. | |
179 | # Default is {mapping_provider!r}. | |
180 | # | |
181 | # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers | |
182 | # for information on implementing a custom mapping provider. | |
183 | # | |
184 | #module: mapping_provider.OidcMappingProvider | |
185 | ||
186 | # Custom configuration values for the module. This section will be passed as | |
187 | # a Python dictionary to the user mapping provider module's `parse_config` | |
188 | # method. | |
189 | # | |
190 | # The examples below are intended for the default provider: they should be | |
191 | # changed if using a custom provider. | |
192 | # | |
193 | config: | |
194 | # name of the claim containing a unique identifier for the user. | |
195 | # Defaults to `sub`, which OpenID Connect compliant providers should provide. | |
196 | # | |
197 | #subject_claim: "sub" | |
198 | ||
199 | # Jinja2 template for the localpart of the MXID. | |
200 | # | |
201 | # When rendering, this template is given the following variables: | |
202 | # * user: The claims returned by the UserInfo Endpoint and/or in the ID | |
203 | # Token | |
204 | # | |
205 | # If this is not set, the user will be prompted to choose their | |
206 | # own username. | |
207 | # | |
208 | #localpart_template: "{{{{ user.preferred_username }}}}" | |
209 | ||
210 | # Jinja2 template for the display name to set on first login. | |
211 | # | |
212 | # If unset, no displayname will be set. | |
213 | # | |
214 | #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" | |
215 | ||
216 | # Jinja2 templates for extra attributes to send back to the client during | |
217 | # login. | |
218 | # | |
219 | # Note that these are non-standard and clients will ignore them without modifications. | |
220 | # | |
221 | #extra_attributes: | |
222 | #birthdate: "{{{{ user.birthdate }}}}" | |
192 | #- idp_id: github | |
193 | # idp_name: Github | |
194 | # discover: false | |
195 | # issuer: "https://github.com/" | |
196 | # client_id: "your-client-id" # TO BE FILLED | |
197 | # client_secret: "your-client-secret" # TO BE FILLED | |
198 | # authorization_endpoint: "https://github.com/login/oauth/authorize" | |
199 | # token_endpoint: "https://github.com/login/oauth/access_token" | |
200 | # userinfo_endpoint: "https://api.github.com/user" | |
201 | # scopes: ["read:user"] | |
202 | # user_mapping_provider: | |
203 | # config: | |
204 | # subject_claim: "id" | |
205 | # localpart_template: "{{{{ user.login }}}}" | 
206 | # display_name_template: "{{{{ user.name }}}}" | 
223 | 207 | """.format( |
224 | 208 | mapping_provider=DEFAULT_USER_MAPPING_PROVIDER |
225 | 209 | ) |
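The quadruple braces above are worth a note: `generate_config_section` passes this string through `str.format`, which halves every doubled brace, so only `{{{{ ... }}}}` survives as a literal Jinja2 `{{ ... }}` in the generated YAML — which is why the GitHub example needs them too:

```python
# str.format halves doubled braces: '{{{{ x }}}}' -> '{{ x }}'.
snippet = '#localpart_template: "{{{{ user.preferred_username }}}}"  # module: {mapping_provider!r}'
print(snippet.format(mapping_provider="DefaultMapper"))
# -> #localpart_template: "{{ user.preferred_username }}"  # module: 'DefaultMapper'
```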
210 | ||
211 | ||
212 | # jsonschema definition of the configuration settings for an oidc identity provider | |
213 | OIDC_PROVIDER_CONFIG_SCHEMA = { | |
214 | "type": "object", | |
215 | "required": ["issuer", "client_id", "client_secret"], | |
216 | "properties": { | |
217 | # TODO: fix the maxLength here depending on what MSC2858 decides | 
218 | # remember that we prefix the ID given here with `oidc-` | |
219 | "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, | |
220 | "idp_name": {"type": "string"}, | |
221 | "idp_icon": {"type": "string"}, | |
222 | "discover": {"type": "boolean"}, | |
223 | "issuer": {"type": "string"}, | |
224 | "client_id": {"type": "string"}, | |
225 | "client_secret": {"type": "string"}, | |
226 | "client_auth_method": { | |
227 | "type": "string", | |
228 | # the following list is the same as the keys of | |
229 | # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it | |
230 | # to avoid importing authlib here. | |
231 | "enum": ["client_secret_basic", "client_secret_post", "none"], | |
232 | }, | |
233 | "scopes": {"type": "array", "items": {"type": "string"}}, | |
234 | "authorization_endpoint": {"type": "string"}, | |
235 | "token_endpoint": {"type": "string"}, | |
236 | "userinfo_endpoint": {"type": "string"}, | |
237 | "jwks_uri": {"type": "string"}, | |
238 | "skip_verification": {"type": "boolean"}, | |
239 | "user_profile_method": { | |
240 | "type": "string", | |
241 | "enum": ["auto", "userinfo_endpoint"], | |
242 | }, | |
243 | "allow_existing_users": {"type": "boolean"}, | |
244 | "user_mapping_provider": {"type": ["object", "null"]}, | |
245 | }, | |
246 | } | |
247 | ||
248 | # the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name | |
249 | OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = { | |
250 | "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}] | |
251 | } | |
252 | ||
253 | ||
254 | # the `oidc_providers` list can either be None (as it is in the default config), or | |
255 | # a list of provider configs, each of which requires an explicit ID and name. | |
256 | OIDC_PROVIDER_LIST_SCHEMA = { | |
257 | "oneOf": [ | |
258 | {"type": "null"}, | |
259 | {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA}, | |
260 | ] | |
261 | } | |
262 | ||
263 | # the `oidc_config` setting can either be None (which it used to be in the default | |
264 | # config), or an object. If an object, it is ignored unless it has an "enabled: True" | |
265 | # property. | |
266 | # | |
267 | # It's *possible* to represent this with jsonschema, but the resultant errors aren't | |
268 | # particularly clear, so we just check for either an object or a null here, and do | |
269 | # additional checks in the code. | |
270 | OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]} | |
271 | ||
272 | # the top-level schema can contain an "oidc_config" and/or an "oidc_providers". | |
273 | MAIN_CONFIG_SCHEMA = { | |
274 | "type": "object", | |
275 | "properties": { | |
276 | "oidc_config": OIDC_CONFIG_SCHEMA, | |
277 | "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA, | |
278 | }, | |
279 | } | |
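A hedged sketch of the check that `validate_config` performs against `MAIN_CONFIG_SCHEMA` defined above, done directly with the jsonschema library; a bad entry (say, one missing `issuer`) would raise `ValidationError`:

```python
import jsonschema

good = {
    "oidc_providers": [
        {
            "idp_id": "github",
            "idp_name": "Github",
            "issuer": "https://github.com/",
            "client_id": "x",
            "client_secret": "y",
        }
    ]
}
jsonschema.validate(good, MAIN_CONFIG_SCHEMA)  # passes silently
```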
280 | ||
281 | ||
282 | def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]: | |
283 | """extract and parse the OIDC provider configs from the config dict | |
284 | ||
285 | The configuration may contain either a single `oidc_config` object with an | |
286 | `enabled: True` property, or a list of provider configurations under | |
287 | `oidc_providers`, *or both*. | |
288 | ||
289 | Returns a generator which yields the OidcProviderConfig objects | |
290 | """ | |
291 | validate_config(MAIN_CONFIG_SCHEMA, config, ()) | |
292 | ||
293 | for i, p in enumerate(config.get("oidc_providers") or []): | |
294 | yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,))) | |
295 | ||
296 | # for backwards-compatibility, it is also possible to provide a single "oidc_config" | |
297 | # object with an "enabled: True" property. | |
298 | oidc_config = config.get("oidc_config") | |
299 | if oidc_config and oidc_config.get("enabled", False): | |
300 | # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that | |
301 | # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA | |
302 | # above), so now we need to validate it. | |
303 | validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) | |
304 | yield _parse_oidc_config_dict(oidc_config, ("oidc_config",)) | |
305 | ||
306 | ||
307 | def _parse_oidc_config_dict( | |
308 | oidc_config: JsonDict, config_path: Tuple[str, ...] | |
309 | ) -> "OidcProviderConfig": | |
310 | """Take the configuration dict and parse it into an OidcProviderConfig | |
311 | ||
312 | Raises: | |
313 | ConfigError if the configuration is malformed. | |
314 | """ | |
315 | ump_config = oidc_config.get("user_mapping_provider", {}) | |
316 | ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) | |
317 | ump_config.setdefault("config", {}) | |
318 | ||
319 | (user_mapping_provider_class, user_mapping_provider_config,) = load_module( | |
320 | ump_config, config_path + ("user_mapping_provider",) | |
321 | ) | |
322 | ||
323 | # Ensure loaded user mapping module has defined all necessary methods | |
324 | required_methods = [ | |
325 | "get_remote_user_id", | |
326 | "map_user_attributes", | |
327 | ] | |
328 | missing_methods = [ | |
329 | method | |
330 | for method in required_methods | |
331 | if not hasattr(user_mapping_provider_class, method) | |
332 | ] | |
333 | if missing_methods: | |
334 | raise ConfigError( | |
335 | "Class %s is missing required " | |
336 | "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),), | |
337 | config_path + ("user_mapping_provider", "module"), | |
338 | ) | |
339 | ||
340 | # MSC2858 will apply certain limits in what can be used as an IdP id, so let's | |
341 | # enforce those limits now. | |
342 | # TODO: factor out this stuff to a generic function | |
343 | idp_id = oidc_config.get("idp_id", "oidc") | |
344 | ||
345 | # TODO: update this validity check based on what MSC2858 decides. | |
346 | valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") | |
347 | ||
348 | if any(c not in valid_idp_chars for c in idp_id): | |
349 | raise ConfigError( | |
350 | 'idp_id may only contain a-z, 0-9, "-", ".", "_"', | |
351 | config_path + ("idp_id",), | |
352 | ) | |
353 | ||
354 | if idp_id[0] not in string.ascii_lowercase: | |
355 | raise ConfigError( | |
356 | "idp_id must start with a-z", config_path + ("idp_id",), | |
357 | ) | |
358 | ||
359 | # prefix the given IdP ID with a prefix specific to the SSO mechanism, to avoid | |
360 | # clashes with other mechanisms (such as SAML, CAS). | |
361 | # | |
362 | # We allow "oidc" as an exception so that people migrating from old-style | |
363 | # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to | |
364 | # a new-style "oidc_providers" entry without changing the idp_id for their provider | |
365 | # (and thereby invalidating their user_external_ids data). | |
366 | ||
367 | if idp_id != "oidc": | |
368 | idp_id = "oidc-" + idp_id | |
369 | ||
370 | # MSC2858 also specifies that the idp_icon must be a valid MXC uri | |
371 | idp_icon = oidc_config.get("idp_icon") | |
372 | if idp_icon is not None: | |
373 | try: | |
374 | parse_and_validate_mxc_uri(idp_icon) | |
375 | except ValueError as e: | |
376 | raise ConfigError( | |
377 | "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) | |
378 | ) from e | |
379 | ||
380 | return OidcProviderConfig( | |
381 | idp_id=idp_id, | |
382 | idp_name=oidc_config.get("idp_name", "OIDC"), | |
383 | idp_icon=idp_icon, | |
384 | discover=oidc_config.get("discover", True), | |
385 | issuer=oidc_config["issuer"], | |
386 | client_id=oidc_config["client_id"], | |
387 | client_secret=oidc_config["client_secret"], | |
388 | client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), | |
389 | scopes=oidc_config.get("scopes", ["openid"]), | |
390 | authorization_endpoint=oidc_config.get("authorization_endpoint"), | |
391 | token_endpoint=oidc_config.get("token_endpoint"), | |
392 | userinfo_endpoint=oidc_config.get("userinfo_endpoint"), | |
393 | jwks_uri=oidc_config.get("jwks_uri"), | |
394 | skip_verification=oidc_config.get("skip_verification", False), | |
395 | user_profile_method=oidc_config.get("user_profile_method", "auto"), | |
396 | allow_existing_users=oidc_config.get("allow_existing_users", False), | |
397 | user_mapping_provider_class=user_mapping_provider_class, | |
398 | user_mapping_provider_config=user_mapping_provider_config, | |
399 | ) | |
400 | ||
401 | ||
402 | @attr.s(slots=True, frozen=True) | |
403 | class OidcProviderConfig: | |
404 | # a unique identifier for this identity provider. Used in the 'user_external_ids' | |
405 | # table, as well as the query/path parameter used in the login protocol. | |
406 | idp_id = attr.ib(type=str) | |
407 | ||
408 | # user-facing name for this identity provider. | |
409 | idp_name = attr.ib(type=str) | |
410 | ||
411 | # Optional MXC URI for icon for this IdP. | |
412 | idp_icon = attr.ib(type=Optional[str]) | |
413 | ||
414 | # whether the OIDC discovery mechanism is used to discover endpoints | |
415 | discover = attr.ib(type=bool) | |
416 | ||
417 | # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to | |
418 | # discover the provider's endpoints. | |
419 | issuer = attr.ib(type=str) | |
420 | ||
421 | # oauth2 client id to use | |
422 | client_id = attr.ib(type=str) | |
423 | ||
424 | # oauth2 client secret to use | |
425 | client_secret = attr.ib(type=str) | |
426 | ||
427 | # auth method to use when exchanging the token. | |
428 | # Valid values are 'client_secret_basic', 'client_secret_post' and | |
429 | # 'none'. | |
430 | client_auth_method = attr.ib(type=str) | |
431 | ||
432 | # list of scopes to request | |
433 | scopes = attr.ib(type=Collection[str]) | |
434 | ||
435 | # the oauth2 authorization endpoint. Required if discovery is disabled. | |
436 | authorization_endpoint = attr.ib(type=Optional[str]) | |
437 | ||
438 | # the oauth2 token endpoint. Required if discovery is disabled. | |
439 | token_endpoint = attr.ib(type=Optional[str]) | |
440 | ||
441 | # the OIDC userinfo endpoint. Required if discovery is disabled and the | |
442 | # "openid" scope is not requested. | |
443 | userinfo_endpoint = attr.ib(type=Optional[str]) | |
444 | ||
445 | # The URI to fetch the JWKS from. Required if discovery is disabled and the | |
446 | # "openid" scope is used. | |
447 | jwks_uri = attr.ib(type=Optional[str]) | |
448 | ||
449 | # Whether to skip metadata verification | |
450 | skip_verification = attr.ib(type=bool) | |
451 | ||
452 | # Whether to fetch the user profile from the userinfo endpoint. Valid | |
453 | # values are: "auto" or "userinfo_endpoint". | |
454 | user_profile_method = attr.ib(type=str) | |
455 | ||
456 | # whether to allow a user logging in via OIDC to match a pre-existing account | |
457 | # instead of failing | |
458 | allow_existing_users = attr.ib(type=bool) | |
459 | ||
460 | # the class of the user mapping provider | |
461 | user_mapping_provider_class = attr.ib(type=Type) | |
462 | ||
463 | # the config of the user mapping provider | |
464 | user_mapping_provider_config = attr.ib() |
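As a rough illustration of the `idp_id` rules enforced above (restricted character set, lowercase first character, and the `oidc-` prefix with the legacy `oidc` value grandfathered in), here is a minimal self-contained sketch; `normalise_idp_id` is a hypothetical helper written for this note, not a function from the diff:

```python
import string

def normalise_idp_id(idp_id: str = "oidc") -> str:
    """Validate an idp_id and apply the 'oidc-' prefix, as the parser above does."""
    valid_chars = set(string.ascii_lowercase + string.digits + "-._")
    if any(c not in valid_chars for c in idp_id):
        raise ValueError('idp_id may only contain a-z, 0-9, "-", ".", "_"')
    if idp_id[0] not in string.ascii_lowercase:
        raise ValueError("idp_id must start with a-z")
    # "oidc" is left unprefixed so deployments migrating from the old-style
    # oidc_config keep the idp_id already recorded in user_external_ids.
    return idp_id if idp_id == "oidc" else "oidc-" + idp_id

assert normalise_idp_id("github") == "oidc-github"
assert normalise_idp_id() == "oidc"
```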
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import os |
16 | from distutils.util import strtobool | |
17 | 16 | |
18 | 17 | import pkg_resources |
19 | 18 | |
20 | 19 | from synapse.api.constants import RoomCreationPreset |
21 | 20 | from synapse.config._base import Config, ConfigError |
22 | 21 | from synapse.types import RoomAlias, UserID |
23 | from synapse.util.stringutils import random_string_with_symbols | |
22 | from synapse.util.stringutils import random_string_with_symbols, strtobool | |
24 | 23 | |
25 | 24 | |
26 | 25 | class AccountValidityConfig(Config): |
49 | 48 | |
50 | 49 | self.startup_job_max_delta = self.period * 10.0 / 100.0 |
51 | 50 | |
52 | if self.renew_by_email_enabled: | |
53 | if "public_baseurl" not in synapse_config: | |
54 | raise ConfigError("Can't send renewal emails without 'public_baseurl'") | |
55 | ||
56 | 51 | template_dir = config.get("template_dir") |
57 | 52 | |
58 | 53 | if not template_dir: |
85 | 80 | section = "registration" |
86 | 81 | |
87 | 82 | def read_config(self, config, **kwargs): |
88 | self.enable_registration = bool( | |
89 | strtobool(str(config.get("enable_registration", False))) | |
83 | self.enable_registration = strtobool( | |
84 | str(config.get("enable_registration", False)) | |
90 | 85 | ) |
91 | 86 | if "disable_registration" in config: |
92 | self.enable_registration = not bool( | |
93 | strtobool(str(config["disable_registration"])) | |
87 | self.enable_registration = not strtobool( | |
88 | str(config["disable_registration"]) | |
94 | 89 | ) |
95 | 90 | |
96 | 91 | self.account_validity = AccountValidityConfig( |
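The hunk above also swaps `distutils.util.strtobool` for a copy in `synapse.util.stringutils` and drops the `bool(...)` wrapper, which only makes sense if the replacement returns a real `bool`. A sketch of what such a helper plausibly looks like, assuming it accepts the same truth values as the distutils original (the exact Synapse implementation may differ):

```python
def strtobool(val: str) -> bool:
    """Convert a string representation of truth to True or False.

    Mirrors distutils.util.strtobool, but returns a bool rather than 0/1.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (val,))

assert strtobool("True") is True
assert strtobool("0") is False
```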
109 | 104 | account_threepid_delegates = config.get("account_threepid_delegates") or {} |
110 | 105 | self.account_threepid_delegate_email = account_threepid_delegates.get("email") |
111 | 106 | self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") |
112 | if self.account_threepid_delegate_msisdn and not self.public_baseurl: | |
113 | raise ConfigError( | |
114 | "The configuration option `public_baseurl` is required if " | |
115 | "`account_threepid_delegate.msisdn` is set, such that " | |
116 | "clients know where to submit validation tokens to. Please " | |
117 | "configure `public_baseurl`." | |
118 | ) | |
119 | 107 | |
120 | 108 | self.default_identity_server = config.get("default_identity_server") |
121 | 109 | self.allow_guest_access = config.get("allow_guest_access", False) |
240 | 228 | # send an email to the account's email address with a renewal link. By |
241 | 229 | # default, no such emails are sent. |
242 | 230 | # |
243 | # If you enable this setting, you will also need to fill out the 'email' and | |
244 | # 'public_baseurl' configuration sections. | |
231 | # If you enable this setting, you will also need to fill out the 'email' | |
232 | # configuration section. You should also check that 'public_baseurl' is set | |
233 | # correctly. | |
245 | 234 | # |
246 | 235 | #renew_at: 1w |
247 | 236 | |
332 | 321 | # The identity server which we suggest that clients should use when users log |
333 | 322 | # in on this server. |
334 | 323 | # |
335 | # (By default, no suggestion is made, so it is left up to the client. | |
336 | # This setting is ignored unless public_baseurl is also set.) | |
324 | # (By default, no suggestion is made, so it is left up to the client.) | |
337 | 325 | # |
338 | 326 | #default_identity_server: https://matrix.org |
339 | 327 | |
357 | 345 | # Servers handling these requests must answer the `/requestToken` endpoints defined |
358 | 346 | # by the Matrix Identity Service API specification: |
359 | 347 | # https://matrix.org/docs/spec/identity_service/latest |
360 | # | |
361 | # If a delegate is specified, the config option public_baseurl must also be filled out. | |
362 | 348 | # |
363 | 349 | account_threepid_delegates: |
364 | 350 | #email: https://example.com # Delegate email sending to example.com |
188 | 188 | import saml2 |
189 | 189 | |
190 | 190 | public_baseurl = self.public_baseurl |
191 | if public_baseurl is None: | |
192 | raise ConfigError("saml2_config requires a public_baseurl to be set") | |
193 | 191 | |
194 | 192 | if self.saml2_grandfathered_mxid_source_attribute: |
195 | 193 | optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) |
25 | 25 | from netaddr import IPSet |
26 | 26 | |
27 | 27 | from synapse.api.room_versions import KNOWN_ROOM_VERSIONS |
28 | from synapse.http.endpoint import parse_and_validate_server_name | |
28 | from synapse.util.stringutils import parse_and_validate_server_name | |
29 | 29 | |
30 | 30 | from ._base import Config, ConfigError |
31 | 31 | |
160 | 160 | self.print_pidfile = config.get("print_pidfile") |
161 | 161 | self.user_agent_suffix = config.get("user_agent_suffix") |
162 | 162 | self.use_frozen_dicts = config.get("use_frozen_dicts", False) |
163 | self.public_baseurl = config.get("public_baseurl") | |
163 | self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( | |
164 | self.server_name, | |
165 | ) | |
166 | if self.public_baseurl[-1] != "/": | |
167 | self.public_baseurl += "/" | |
164 | 168 | |
165 | 169 | # Whether to enable user presence. |
166 | 170 | self.use_presence = config.get("use_presence", True) |
316 | 320 | # Always blacklist 0.0.0.0, :: |
317 | 321 | self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) |
318 | 322 | |
319 | if self.public_baseurl is not None: | |
320 | if self.public_baseurl[-1] != "/": | |
321 | self.public_baseurl += "/" | |
322 | 323 | self.start_pushers = config.get("start_pushers", True) |
323 | 324 | |
324 | 325 | # (undocumented) option for torturing the worker-mode replication a bit, |
739 | 740 | # |
740 | 741 | #web_client_location: https://riot.example.com/ |
741 | 742 | |
742 | # The public-facing base URL that clients use to access this HS | |
743 | # (not including _matrix/...). This is the same URL a user would | |
744 | # enter into the 'custom HS URL' field on their client. If you | |
745 | # use synapse with a reverse proxy, this should be the URL to reach | |
746 | # synapse via the proxy. | |
743 | # The public-facing base URL that clients use to access this Homeserver (not | |
744 | # including _matrix/...). This is the same URL a user might enter into the | |
745 | # 'Custom Homeserver URL' field on their client. If you use Synapse with a | |
746 | # reverse proxy, this should be the URL to reach Synapse via the proxy. | |
747 | # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see | |
748 | # 'listeners' below). | |
749 | # | |
750 | # If this is left unset, it defaults to 'https://<server_name>/'. (Note that | |
751 | # this will not work unless you configure Synapse or a reverse-proxy to listen | |
752 | # on port 443.) | |
747 | 753 | # |
748 | 754 | #public_baseurl: https://example.com/ |
749 | 755 |
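The net effect of the `server.py` changes above is that `public_baseurl` is always set: when absent it is derived from `server_name`, and a trailing slash is guaranteed in either case. A small sketch of the equivalent logic, with a plain dict standing in for the parsed config:

```python
def resolve_public_baseurl(config: dict, server_name: str) -> str:
    # Fall back to https://<server_name>/ when the option is unset or empty.
    public_baseurl = config.get("public_baseurl") or "https://%s/" % (server_name,)
    # Normalise so later code can blindly append paths such as "_matrix/...".
    if public_baseurl[-1] != "/":
        public_baseurl += "/"
    return public_baseurl

assert resolve_public_baseurl({}, "example.com") == "https://example.com/"
assert (
    resolve_public_baseurl({"public_baseurl": "https://matrix.example.com"}, "x")
    == "https://matrix.example.com/"
)
```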
30 | 30 | |
31 | 31 | # Read templates from disk |
32 | 32 | ( |
33 | self.sso_login_idp_picker_template, | |
33 | 34 | self.sso_redirect_confirm_template, |
34 | 35 | self.sso_auth_confirm_template, |
35 | 36 | self.sso_error_template, |
36 | 37 | sso_account_deactivated_template, |
37 | 38 | sso_auth_success_template, |
39 | self.sso_auth_bad_user_template, | |
38 | 40 | ) = self.read_templates( |
39 | 41 | [ |
42 | "sso_login_idp_picker.html", | |
40 | 43 | "sso_redirect_confirm.html", |
41 | 44 | "sso_auth_confirm.html", |
42 | 45 | "sso_error.html", |
43 | 46 | "sso_account_deactivated.html", |
44 | 47 | "sso_auth_success.html", |
48 | "sso_auth_bad_user.html", | |
45 | 49 | ], |
46 | 50 | template_dir, |
47 | 51 | ) |
59 | 63 | # gracefully to the client). This would make it pointless to ask the user for |
60 | 64 | # confirmation, since the URL the confirmation page would be showing wouldn't be |
61 | 65 | # the client's. |
62 | # public_baseurl is an optional setting, so we only add the fallback's URL to the | |
63 | # list if it's provided (because we can't figure out what that URL is otherwise). | |
64 | if self.public_baseurl: | |
65 | login_fallback_url = self.public_baseurl + "_matrix/static/client/login" | |
66 | self.sso_client_whitelist.append(login_fallback_url) | |
66 | login_fallback_url = self.public_baseurl + "_matrix/static/client/login" | |
67 | self.sso_client_whitelist.append(login_fallback_url) | |
67 | 68 | |
68 | 69 | def generate_config_section(self, **kwargs): |
69 | 70 | return """\ |
81 | 82 | # phishing attacks from evil.site. To avoid this, include a slash after the |
82 | 83 | # hostname: "https://my.client/". |
83 | 84 | # |
84 | # If public_baseurl is set, then the login fallback page (used by clients | |
85 | # that don't natively support the required login flows) is whitelisted in | |
86 | # addition to any URLs in this list. | |
85 | # The login fallback page (used by clients that don't natively support the | |
86 | # required login flows) is automatically whitelisted in addition to any URLs | |
87 | # in this list. | |
87 | 88 | # |
88 | 89 | # By default, this list is empty. |
89 | 90 | # |
96 | 97 | # directory, default templates from within the Synapse package will be used. |
97 | 98 | # |
98 | 99 | # Synapse will look for the following templates in this directory: |
100 | # | |
101 | # * HTML page to prompt the user to choose an Identity Provider during | |
102 | # login: 'sso_login_idp_picker.html'. | |
103 | # | |
104 | # This is only used if multiple SSO Identity Providers are configured. | |
105 | # | |
106 | # When rendering, this template is given the following variables: | |
107 | # * redirect_url: the URL that the user will be redirected to after | |
108 | # login. Needs manual escaping (see | |
109 | # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). | |
110 | # | |
111 | # * server_name: the homeserver's name. | |
112 | # | |
113 | # * providers: a list of available Identity Providers. Each element is | |
114 | # an object with the following attributes: | |
115 | # * idp_id: unique identifier for the IdP | |
116 | # * idp_name: user-facing name for the IdP | |
117 | # | |
118 | # The rendered HTML page should contain a form which submits its results | |
119 | # back as a GET request, with the following query parameters: | |
120 | # | |
121 | # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed | |
122 | # to the template) | |
123 | # | |
124 | # * idp: the 'idp_id' of the chosen IdP. | |
99 | 125 | # |
100 | 126 | # * HTML page for a confirmation step before redirecting back to the client |
101 | 127 | # with the login token: 'sso_redirect_confirm.html'. |
132 | 158 | # |
133 | 159 | # This template has no additional variables. |
134 | 160 | # |
161 | # * HTML page shown after a user-interactive authentication session which | |
162 | # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. | |
163 | # | |
164 | # When rendering, this template is given the following variables: | |
165 | # * server_name: the homeserver's name. | |
166 | # * user_id_to_verify: the MXID of the user that we are trying to | |
167 | # validate. | |
168 | # | |
135 | 169 | # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) |
136 | 170 | # attempts to login: 'sso_account_deactivated.html'. |
137 | 171 | # |
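Per the template documentation added above, the IdP picker page submits its result as a GET request carrying `redirectUrl` and `idp` query parameters. A hedged sketch of assembling such a submission URL; the endpoint path used here is purely illustrative, not taken from the diff:

```python
from urllib.parse import urlencode

def build_picker_submission(endpoint: str, redirect_url: str, idp_id: str) -> str:
    # The rendered form sends these two query parameters back as a GET.
    return "%s?%s" % (endpoint, urlencode({"redirectUrl": redirect_url, "idp": idp_id}))

# Hypothetical endpoint path, for illustration only.
print(build_picker_submission(
    "https://example.com/_synapse/client/pick_idp",
    "https://client.example/after-login",
    "oidc-github",
))
```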
52 | 52 | default=["master"], type=List[str], converter=_instance_to_list_converter |
53 | 53 | ) |
54 | 54 | typing = attr.ib(default="master", type=str) |
55 | to_device = attr.ib( | |
56 | default=["master"], type=List[str], converter=_instance_to_list_converter, | |
57 | ) | |
58 | account_data = attr.ib( | |
59 | default=["master"], type=List[str], converter=_instance_to_list_converter, | |
60 | ) | |
61 | receipts = attr.ib( | |
62 | default=["master"], type=List[str], converter=_instance_to_list_converter, | |
63 | ) | |
55 | 64 | |
56 | 65 | |
57 | 66 | class WorkerConfig(Config): |
123 | 132 | |
124 | 133 | # Check that the configured writers for events and typing also appears in |
125 | 134 | # `instance_map`. |
126 | for stream in ("events", "typing"): | |
135 | for stream in ("events", "typing", "to_device", "account_data", "receipts"): | |
127 | 136 | instances = _instance_to_list_converter(getattr(self.writers, stream)) |
128 | 137 | for instance in instances: |
129 | 138 | if instance != "master" and instance not in self.instance_map: |
132 | 141 | % (instance, stream) |
133 | 142 | ) |
134 | 143 | |
144 | if len(self.writers.to_device) != 1: | |
145 | raise ConfigError( | |
146 | "Must only specify one instance to handle `to_device` messages." | |
147 | ) | |
148 | ||
149 | if len(self.writers.account_data) != 1: | |
150 | raise ConfigError( | |
151 | "Must only specify one instance to handle `account_data` messages." | |
152 | ) | |
153 | ||
154 | if len(self.writers.receipts) != 1: | |
155 | raise ConfigError( | |
156 | "Must only specify one instance to handle `receipts` messages." | |
157 | ) | |
158 | ||
135 | 159 | self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) |
136 | 160 | |
137 | 161 | # Whether this worker should run background tasks or not. |
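The worker-config checks above generalise the existing events/typing validation: every named writer must be `master` or appear in `instance_map`, and each of the three newly movable streams allows exactly one writer. A condensed sketch of that validation over plain dicts, not the actual config classes:

```python
def check_writers(writers: dict, instance_map: dict) -> None:
    """Sketch of the writer validation described above, on plain dicts."""
    def as_list(v):
        return [v] if isinstance(v, str) else v

    for stream in ("events", "typing", "to_device", "account_data", "receipts"):
        for instance in as_list(writers.get(stream, ["master"])):
            if instance != "master" and instance not in instance_map:
                raise ValueError(
                    "Instance %r is configured to write %s but does not appear "
                    "in `instance_map` config." % (instance, stream)
                )

    # These three streams are not sharded yet: exactly one writer each.
    for stream in ("to_device", "account_data", "receipts"):
        if len(as_list(writers.get(stream, ["master"]))) != 1:
            raise ValueError(
                "Must only specify one instance to handle `%s` messages." % (stream,)
            )
```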
16 | 16 | |
17 | 17 | import abc |
18 | 18 | import os |
19 | from distutils.util import strtobool | |
20 | 19 | from typing import Dict, Optional, Tuple, Type |
21 | 20 | |
22 | 21 | from unpaddedbase64 import encode_base64 |
25 | 24 | from synapse.types import JsonDict, RoomStreamToken |
26 | 25 | from synapse.util.caches import intern_dict |
27 | 26 | from synapse.util.frozenutils import freeze |
27 | from synapse.util.stringutils import strtobool | |
28 | 28 | |
29 | 29 | # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents |
30 | 30 | # bugs where we accidentally share e.g. signature dicts. However, converting a |
33 | 33 | # NOTE: This is overridden by the configuration by the Synapse worker apps, but |
34 | 34 | # for the sake of tests, it is set here while it cannot be configured on the |
35 | 35 | # homeserver object itself. |
36 | ||
36 | 37 | USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) |
37 | 38 | |
38 | 39 |
78 | 78 | "state_key", |
79 | 79 | "depth", |
80 | 80 | "prev_events", |
81 | "prev_state", | |
82 | 81 | "auth_events", |
83 | 82 | "origin", |
84 | 83 | "origin_server_ts", |
85 | "membership", | |
86 | 84 | ] |
85 | ||
86 | # Room versions from before MSC2176 had additional allowed keys. | |
87 | if not room_version.msc2176_redaction_rules: | |
88 | allowed_keys.extend(["prev_state", "membership"]) | |
87 | 89 | |
88 | 90 | event_type = event_dict["type"] |
89 | 91 | |
97 | 99 | if event_type == EventTypes.Member: |
98 | 100 | add_fields("membership") |
99 | 101 | elif event_type == EventTypes.Create: |
102 | # MSC2176 rules state that create events cannot be redacted. | |
103 | if room_version.msc2176_redaction_rules: | |
104 | return event_dict | |
105 | ||
100 | 106 | add_fields("creator") |
101 | 107 | elif event_type == EventTypes.JoinRules: |
102 | 108 | add_fields("join_rule") |
111 | 117 | "kick", |
112 | 118 | "redact", |
113 | 119 | ) |
120 | ||
121 | if room_version.msc2176_redaction_rules: | |
122 | add_fields("invite") | |
123 | ||
114 | 124 | elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: |
115 | 125 | add_fields("aliases") |
116 | 126 | elif event_type == EventTypes.RoomHistoryVisibility: |
117 | 127 | add_fields("history_visibility") |
128 | elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules: | |
129 | add_fields("redacts") | |
118 | 130 | |
119 | 131 | allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys} |
120 | 132 |
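The redaction changes branch on `room_version.msc2176_redaction_rules`: pre-MSC2176 room versions keep `prev_state` and `membership` in the allowed top-level keys, create events become unredactable, and `invite` (on power levels) plus `redacts` (on redactions) are additionally preserved. A condensed sketch of the top-level key selection, with a boolean standing in for the room-version object and an abridged base list:

```python
# Abridged list of the top-level keys every room version keeps.
BASE_ALLOWED_KEYS = [
    "event_id", "type", "room_id", "sender", "state_key", "content",
    "hashes", "signatures", "depth", "prev_events", "auth_events",
    "origin", "origin_server_ts",
]

def allowed_top_level_keys(msc2176_redaction_rules: bool) -> list:
    allowed_keys = list(BASE_ALLOWED_KEYS)
    # Room versions from before MSC2176 had additional allowed keys.
    if not msc2176_redaction_rules:
        allowed_keys.extend(["prev_state", "membership"])
    return allowed_keys

assert "membership" not in allowed_top_level_keys(True)
assert "membership" in allowed_top_level_keys(False)
```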
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 | import logging |
17 | import random | |
17 | 18 | from typing import ( |
18 | 19 | TYPE_CHECKING, |
19 | 20 | Any, |
47 | 48 | from synapse.federation.federation_base import FederationBase, event_from_pdu_json |
48 | 49 | from synapse.federation.persistence import TransactionActions |
49 | 50 | from synapse.federation.units import Edu, Transaction |
50 | from synapse.http.endpoint import parse_server_name | |
51 | 51 | from synapse.http.servlet import assert_params_in_dict |
52 | 52 | from synapse.logging.context import ( |
53 | 53 | make_deferred_yieldable, |
64 | 64 | from synapse.util import glob_to_regex, json_decoder, unwrapFirstError |
65 | 65 | from synapse.util.async_helpers import Linearizer, concurrently_execute |
66 | 66 | from synapse.util.caches.response_cache import ResponseCache |
67 | from synapse.util.stringutils import parse_server_name | |
67 | 68 | |
68 | 69 | if TYPE_CHECKING: |
69 | 70 | from synapse.server import HomeServer |
859 | 860 | ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] |
860 | 861 | self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] |
861 | 862 | |
862 | # Map from type to instance name that we should route EDU handling to. | |
863 | self._edu_type_to_instance = {} # type: Dict[str, str] | |
863 | # Map from type to instance names that we should route EDU handling to. | |
864 | # We randomly choose one instance from the list to route to for each new | |
865 | # EDU received. | |
866 | self._edu_type_to_instance = {} # type: Dict[str, List[str]] | |
864 | 867 | |
865 | 868 | def register_edu_handler( |
866 | 869 | self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] |
904 | 907 | def register_instance_for_edu(self, edu_type: str, instance_name: str): |
905 | 908 | """Register that the EDU handler is on a different instance than master. |
906 | 909 | """ |
907 | self._edu_type_to_instance[edu_type] = instance_name | |
910 | self._edu_type_to_instance[edu_type] = [instance_name] | |
911 | ||
912 | def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): | |
913 | """Register that the EDU handler is on multiple instances. | |
914 | """ | |
915 | self._edu_type_to_instance[edu_type] = instance_names | |
908 | 916 | |
909 | 917 | async def on_edu(self, edu_type: str, origin: str, content: dict): |
910 | 918 | if not self.config.use_presence and edu_type == "m.presence": |
923 | 931 | return |
924 | 932 | |
925 | 933 | # Check if we can route it somewhere else that isn't us |
926 | route_to = self._edu_type_to_instance.get(edu_type, "master") | |
927 | if route_to != self._instance_name: | |
934 | instances = self._edu_type_to_instance.get(edu_type, ["master"]) | |
935 | if self._instance_name not in instances: | |
936 | # Pick an instance randomly so that we don't overload one. | |
937 | route_to = random.choice(instances) | |
938 | ||
928 | 939 | try: |
929 | 940 | await self._send_edu( |
930 | 941 | instance_name=route_to, |
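EDU routing now maps each EDU type to a *list* of instance names and picks one at random, so several workers can share e.g. `m.direct_to_device` traffic. A minimal standalone sketch of the dispatch decision, assuming the registration methods shown in the diff:

```python
import random
from typing import Dict, List

class EduRouter:
    def __init__(self, instance_name: str):
        self._instance_name = instance_name
        # Map from EDU type to the instance names that handle it.
        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]

    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
        self._edu_type_to_instance[edu_type] = instance_names

    def route(self, edu_type: str) -> str:
        instances = self._edu_type_to_instance.get(edu_type, ["master"])
        if self._instance_name in instances:
            return self._instance_name  # handle it locally
        # Pick an instance randomly so that we don't overload one.
        return random.choice(instances)

router = EduRouter("master")
router.register_instances_for_edu("m.direct_to_device", ["worker1", "worker2"])
assert router.route("m.direct_to_device") in ("worker1", "worker2")
```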
27 | 27 | FEDERATION_V1_PREFIX, |
28 | 28 | FEDERATION_V2_PREFIX, |
29 | 29 | ) |
30 | from synapse.http.endpoint import parse_and_validate_server_name | |
31 | 30 | from synapse.http.server import JsonResource |
32 | 31 | from synapse.http.servlet import ( |
33 | 32 | parse_boolean_from_args, |
44 | 43 | ) |
45 | 44 | from synapse.server import HomeServer |
46 | 45 | from synapse.types import ThirdPartyInstanceID, get_domain_from_id |
46 | from synapse.util.stringutils import parse_and_validate_server_name | |
47 | 47 | from synapse.util.versionstring import get_version_string |
48 | 48 | |
49 | 49 | logger = logging.getLogger(__name__) |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2015, 2016 OpenMarket Ltd |
2 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
11 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 13 | # See the License for the specific language governing permissions and |
13 | 14 | # limitations under the License. |
15 | import random | |
14 | 16 | from typing import TYPE_CHECKING, List, Tuple |
15 | 17 | |
18 | from synapse.replication.http.account_data import ( | |
19 | ReplicationAddTagRestServlet, | |
20 | ReplicationRemoveTagRestServlet, | |
21 | ReplicationRoomAccountDataRestServlet, | |
22 | ReplicationUserAccountDataRestServlet, | |
23 | ) | |
16 | 24 | from synapse.types import JsonDict, UserID |
17 | 25 | |
18 | 26 | if TYPE_CHECKING: |
19 | 27 | from synapse.app.homeserver import HomeServer |
28 | ||
29 | ||
30 | class AccountDataHandler: | |
31 | def __init__(self, hs: "HomeServer"): | |
32 | self._store = hs.get_datastore() | |
33 | self._instance_name = hs.get_instance_name() | |
34 | self._notifier = hs.get_notifier() | |
35 | ||
36 | self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) | |
37 | self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) | |
38 | self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) | |
39 | self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) | |
40 | self._account_data_writers = hs.config.worker.writers.account_data | |
41 | ||
42 | async def add_account_data_to_room( | |
43 | self, user_id: str, room_id: str, account_data_type: str, content: JsonDict | |
44 | ) -> int: | |
45 | """Add some account_data to a room for a user. | |
46 | ||
47 | Args: | |
48 | user_id: The user to add the account_data for. | |
49 | room_id: The room to add the account_data to. | |
50 | account_data_type: The type of account_data to add. | |
51 | content: A JSON object to associate with the account_data. | |
52 | ||
53 | Returns: | |
54 | The maximum stream ID. | |
55 | """ | |
56 | if self._instance_name in self._account_data_writers: | |
57 | max_stream_id = await self._store.add_account_data_to_room( | |
58 | user_id, room_id, account_data_type, content | |
59 | ) | |
60 | ||
61 | self._notifier.on_new_event( | |
62 | "account_data_key", max_stream_id, users=[user_id] | |
63 | ) | |
64 | ||
65 | return max_stream_id | |
66 | else: | |
67 | response = await self._room_data_client( | |
68 | instance_name=random.choice(self._account_data_writers), | |
69 | user_id=user_id, | |
70 | room_id=room_id, | |
71 | account_data_type=account_data_type, | |
72 | content=content, | |
73 | ) | |
74 | return response["max_stream_id"] | |
75 | ||
76 | async def add_account_data_for_user( | |
77 | self, user_id: str, account_data_type: str, content: JsonDict | |
78 | ) -> int: | |
79 | """Add some account_data to a room for a user. | |
80 | ||
81 | Args: | |
82 | user_id: The user to add the account_data for. | |
83 | account_data_type: The type of account_data to add. | |
84 | content: A JSON object to associate with the account_data. | |
85 | ||
86 | Returns: | |
87 | The maximum stream ID. | |
88 | """ | |
89 | ||
90 | if self._instance_name in self._account_data_writers: | |
91 | max_stream_id = await self._store.add_account_data_for_user( | |
92 | user_id, account_data_type, content | |
93 | ) | |
94 | ||
95 | self._notifier.on_new_event( | |
96 | "account_data_key", max_stream_id, users=[user_id] | |
97 | ) | |
98 | return max_stream_id | |
99 | else: | |
100 | response = await self._user_data_client( | |
101 | instance_name=random.choice(self._account_data_writers), | |
102 | user_id=user_id, | |
103 | account_data_type=account_data_type, | |
104 | content=content, | |
105 | ) | |
106 | return response["max_stream_id"] | |
107 | ||
108 | async def add_tag_to_room( | |
109 | self, user_id: str, room_id: str, tag: str, content: JsonDict | |
110 | ) -> int: | |
111 | """Add a tag to a room for a user. | |
112 | ||
113 | Args: | |
114 | user_id: The user to add a tag for. | |
115 | room_id: The room to add a tag for. | |
116 | tag: The tag name to add. | |
117 | content: A json object to associate with the tag. | |
118 | ||
119 | Returns: | |
120 | The next account data ID. | |
121 | """ | |
122 | if self._instance_name in self._account_data_writers: | |
123 | max_stream_id = await self._store.add_tag_to_room( | |
124 | user_id, room_id, tag, content | |
125 | ) | |
126 | ||
127 | self._notifier.on_new_event( | |
128 | "account_data_key", max_stream_id, users=[user_id] | |
129 | ) | |
130 | return max_stream_id | |
131 | else: | |
132 | response = await self._add_tag_client( | |
133 | instance_name=random.choice(self._account_data_writers), | |
134 | user_id=user_id, | |
135 | room_id=room_id, | |
136 | tag=tag, | |
137 | content=content, | |
138 | ) | |
139 | return response["max_stream_id"] | |
140 | ||
141 | async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: | |
142 | """Remove a tag from a room for a user. | |
143 | ||
144 | Returns: | |
145 | The next account data ID. | |
146 | """ | |
147 | if self._instance_name in self._account_data_writers: | |
148 | max_stream_id = await self._store.remove_tag_from_room( | |
149 | user_id, room_id, tag | |
150 | ) | |
151 | ||
152 | self._notifier.on_new_event( | |
153 | "account_data_key", max_stream_id, users=[user_id] | |
154 | ) | |
155 | return max_stream_id | |
156 | else: | |
157 | response = await self._remove_tag_client( | |
158 | instance_name=random.choice(self._account_data_writers), | |
159 | user_id=user_id, | |
160 | room_id=room_id, | |
161 | tag=tag, | |
162 | ) | |
163 | return response["max_stream_id"] | |
20 | 164 | |
21 | 165 | |
22 | 166 | class AccountDataEventSource: |
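Every method on the new `AccountDataHandler` follows the same shape: if this instance is one of the configured `account_data` writers, write to the database and poke the notifier; otherwise call the replication HTTP endpoint on a randomly chosen writer and relay its `max_stream_id`. A stripped-down sketch of that pattern, with callables standing in for the store and replication-client calls:

```python
import random
from typing import Awaitable, Callable, List

async def write_or_forward(
    instance_name: str,
    writers: List[str],
    do_local_write: Callable[[], Awaitable[int]],
    notify: Callable[[int], None],
    do_remote_call: Callable[[str], Awaitable[dict]],
) -> int:
    """Sketch of the pattern each handler method above follows."""
    if instance_name in writers:
        # We are a writer: persist directly, then wake up /sync streams.
        max_stream_id = await do_local_write()
        notify(max_stream_id)
        return max_stream_id
    # Otherwise forward over replication to a randomly chosen writer.
    response = await do_remote_call(random.choice(writers))
    return response["max_stream_id"]
```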
48 | 48 | UserDeactivatedError, |
49 | 49 | ) |
50 | 50 | from synapse.api.ratelimiting import Ratelimiter |
51 | from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS | |
51 | from synapse.handlers._base import BaseHandler | |
52 | from synapse.handlers.ui_auth import ( | |
53 | INTERACTIVE_AUTH_CHECKERS, | |
54 | UIAuthSessionDataConstants, | |
55 | ) | |
52 | 56 | from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker |
57 | from synapse.http import get_request_user_agent | |
53 | 58 | from synapse.http.server import finish_request, respond_with_html |
54 | 59 | from synapse.http.site import SynapseRequest |
55 | 60 | from synapse.logging.context import defer_to_thread |
60 | 65 | from synapse.util.async_helpers import maybe_awaitable |
61 | 66 | from synapse.util.msisdn import phone_number_to_msisdn |
62 | 67 | from synapse.util.threepids import canonicalise_email |
63 | ||
64 | from ._base import BaseHandler | |
65 | 68 | |
66 | 69 | if TYPE_CHECKING: |
67 | 70 | from synapse.app.homeserver import HomeServer |
259 | 262 | # authenticating for an operation to occur on their account. |
260 | 263 | self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template |
261 | 264 | |
262 | # The following template is shown after a successful user interactive | |
263 | # authentication session. It tells the user they can close the window. | |
264 | self._sso_auth_success_template = hs.config.sso_auth_success_template | |
265 | ||
266 | 265 | # The following template is shown during the SSO authentication process if |
267 | 266 | # the account is deactivated. |
268 | 267 | self._sso_account_deactivated_template = ( |
283 | 282 | requester: Requester, |
284 | 283 | request: SynapseRequest, |
285 | 284 | request_body: Dict[str, Any], |
286 | clientip: str, | |
287 | 285 | description: str, |
288 | 286 | ) -> Tuple[dict, Optional[str]]: |
289 | 287 | """ |
299 | 297 | request: The request sent by the client. |
300 | 298 | |
301 | 299 | request_body: The body of the request sent by the client |
302 | ||
303 | clientip: The IP address of the client. | |
304 | 300 | |
305 | 301 | description: A human readable string to be displayed to the user that |
306 | 302 | describes the operation happening on their account. |
337 | 333 | request_body.pop("auth", None) |
338 | 334 | return request_body, None |
339 | 335 | |
340 | user_id = requester.user.to_string() | |
336 | requester_user_id = requester.user.to_string() | |
341 | 337 | |
342 | 338 | # Check if we should be ratelimited due to too many previous failed attempts |
343 | self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) | |
339 | self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) | |
344 | 340 | |
345 | 341 | # build a list of supported flows |
346 | 342 | supported_ui_auth_types = await self._get_available_ui_auth_types( |
348 | 344 | ) |
349 | 345 | flows = [[login_type] for login_type in supported_ui_auth_types] |
350 | 346 | |
347 | def get_new_session_data() -> JsonDict: | |
348 | return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id} | |
349 | ||
351 | 350 | try: |
352 | 351 | result, params, session_id = await self.check_ui_auth( |
353 | flows, request, request_body, clientip, description | |
352 | flows, request, request_body, description, get_new_session_data, | |
354 | 353 | ) |
355 | 354 | except LoginError: |
356 | 355 | # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). |
357 | self._failed_uia_attempts_ratelimiter.can_do_action(user_id) | |
356 | self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) | |
358 | 357 | raise |
359 | 358 | |
360 | 359 | # find the completed login type |
362 | 361 | if login_type not in result: |
363 | 362 | continue |
364 | 363 | |
365 | user_id = result[login_type] | |
364 | validated_user_id = result[login_type] | |
366 | 365 | break |
367 | 366 | else: |
368 | 367 | # this can't happen |
369 | 368 | raise Exception("check_auth returned True but no successful login type") |
370 | 369 | |
371 | 370 | # check that the UI auth matched the access token |
372 | if user_id != requester.user.to_string(): | |
371 | if validated_user_id != requester_user_id: | |
373 | 372 | raise AuthError(403, "Invalid auth") |
374 | 373 | |
375 | 374 | # Note that the access token has been validated. |
401 | 400 | |
402 | 401 | # if sso is enabled, allow the user to log in via SSO iff they have a mapping |
403 | 402 | # from sso to mxid. |
404 | if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: | |
405 | if await self.store.get_external_ids_by_user(user.to_string()): | |
406 | ui_auth_types.add(LoginType.SSO) | |
407 | ||
408 | # Our CAS impl does not (yet) correctly register users in user_external_ids, | |
409 | # so always offer that if it's available. | |
410 | if self.hs.config.cas.cas_enabled: | |
403 | if await self.hs.get_sso_handler().get_identity_providers_for_user( | |
404 | user.to_string() | |
405 | ): | |
411 | 406 | ui_auth_types.add(LoginType.SSO) |
412 | 407 | |
413 | 408 | return ui_auth_types |
425 | 420 | flows: List[List[str]], |
426 | 421 | request: SynapseRequest, |
427 | 422 | clientdict: Dict[str, Any], |
428 | clientip: str, | |
429 | 423 | description: str, |
424 | get_new_session_data: Optional[Callable[[], JsonDict]] = None, | |
430 | 425 | ) -> Tuple[dict, dict, str]: |
431 | 426 | """ |
432 | 427 | Takes a dictionary sent by the client in the login / registration |
447 | 442 | clientdict: The dictionary from the client root level, not the |
448 | 443 | 'auth' key: this method prompts for auth if none is sent. |
449 | 444 | |
450 | clientip: The IP address of the client. | |
451 | ||
452 | 445 | description: A human readable string to be displayed to the user that |
453 | 446 | describes the operation happening on their account. |
447 | ||
448 | get_new_session_data: | |
449 | an optional callback which will be called when starting a new session. | |
450 | It should return data to be stored as part of the session. | |
451 | ||
452 | The keys of the returned data should be entries in | |
453 | UIAuthSessionDataConstants. | |
454 | 454 | |
455 | 455 | Returns: |
456 | 456 | A tuple of (creds, params, session_id). |
479 | 479 | |
480 | 480 | # If there's no session ID, create a new session. |
481 | 481 | if not sid: |
482 | new_session_data = get_new_session_data() if get_new_session_data else {} | |
483 | ||
482 | 484 | session = await self.store.create_ui_auth_session( |
483 | 485 | clientdict, uri, method, description |
484 | 486 | ) |
487 | ||
488 | for k, v in new_session_data.items(): | |
489 | await self.set_session_data(session.session_id, k, v) | |
485 | 490 | |
486 | 491 | else: |
487 | 492 | try: |
538 | 543 | # authentication flow. |
539 | 544 | await self.store.set_ui_auth_clientdict(sid, clientdict) |
540 | 545 | |
541 | user_agent = request.get_user_agent("") | |
546 | user_agent = get_request_user_agent(request) | |
547 | clientip = request.getClientIP() | |
542 | 548 | |
543 | 549 | await self.store.add_user_agent_ip_to_ui_auth_session( |
544 | 550 | session.session_id, user_agent, clientip |
643 | 649 | |
644 | 650 | Args: |
645 | 651 | session_id: The ID of this session as returned from check_auth |
646 | key: The key to store the data under | |
652 | key: The key to store the data under. An entry from | |
653 | UIAuthSessionDataConstants. | |
647 | 654 | value: The data to store |
648 | 655 | """ |
649 | 656 | try: |
659 | 666 | |
660 | 667 | Args: |
661 | 668 | session_id: The ID of this session as returned from check_auth |
662 | key: The key to store the data under | |
669 | key: The key the data was stored under. An entry from | |
670 | UIAuthSessionDataConstants. | |
663 | 671 | default: Value to return if the key has not been set |
664 | 672 | """ |
665 | 673 | try: |
1333 | 1341 | else: |
1334 | 1342 | return False |
1335 | 1343 | |
1336 | async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: | |
1344 | async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str: | |
1337 | 1345 | """ |
1338 | 1346 | Get the HTML for the SSO redirect confirmation page. |
1339 | 1347 | |
1340 | 1348 | Args: |
1341 | redirect_url: The URL to redirect to the SSO provider. | |
1349 | request: The incoming HTTP request | |
1342 | 1350 | session_id: The user interactive authentication session ID. |
1343 | 1351 | |
1344 | 1352 | Returns: |
1348 | 1356 | session = await self.store.get_ui_auth_session(session_id) |
1349 | 1357 | except StoreError: |
1350 | 1358 | raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) |
1359 | ||
1360 | user_id_to_verify = await self.get_session_data( | |
1361 | session_id, UIAuthSessionDataConstants.REQUEST_USER_ID | |
1362 | ) # type: str | |
1363 | ||
1364 | idps = await self.hs.get_sso_handler().get_identity_providers_for_user( | |
1365 | user_id_to_verify | |
1366 | ) | |
1367 | ||
1368 | if not idps: | |
1369 | # we checked that the user had some remote identities before offering an SSO | |
1370 | # flow, so either it's been deleted or the client has requested SSO despite | |
1371 | # it not being offered. | |
1372 | raise SynapseError(400, "User has no SSO identities") | |
1373 | ||
1374 | # for now, just pick one | |
1375 | idp_id, sso_auth_provider = next(iter(idps.items())) | |
1376 | if len(idps) > 1: | |
1377 | logger.warning( | |
1378 | "User %r has previously logged in with multiple SSO IdPs; arbitrarily " | |
1379 | "picking %r", | |
1380 | user_id_to_verify, | |
1381 | idp_id, | |
1382 | ) | |
1383 | ||
1384 | redirect_url = await sso_auth_provider.handle_redirect_request( | |
1385 | request, None, session_id | |
1386 | ) | |
1387 | ||
1351 | 1388 | return self._sso_auth_confirm_template.render( |
1352 | 1389 | description=session.description, redirect_url=redirect_url, |
1353 | 1390 | ) |
1354 | ||
1355 | async def complete_sso_ui_auth( | |
1356 | self, registered_user_id: str, session_id: str, request: Request, | |
1357 | ): | |
1358 | """Having figured out a mxid for this user, complete the HTTP request | |
1359 | ||
1360 | Args: | |
1361 | registered_user_id: The registered user ID to complete SSO login for. | |
1362 | session_id: The ID of the user-interactive auth session. | |
1363 | request: The request to complete. | |
1364 | """ | |
1365 | # Mark the stage of the authentication as successful. | |
1366 | # Save the user who authenticated with SSO, this will be used to ensure | |
1367 | # that the account be modified is also the person who logged in. | |
1368 | await self.store.mark_ui_auth_stage_complete( | |
1369 | session_id, LoginType.SSO, registered_user_id | |
1370 | ) | |
1371 | ||
1372 | # Render the HTML and return. | |
1373 | html = self._sso_auth_success_template | |
1374 | respond_with_html(request, 200, html) | |
1375 | 1391 | |
1376 | 1392 | async def complete_sso_login( |
1377 | 1393 | self, |
1487 | 1503 | @staticmethod |
1488 | 1504 | def add_query_param_to_url(url: str, param_name: str, param: Any): |
1489 | 1505 | url_parts = list(urllib.parse.urlparse(url)) |
1490 | query = dict(urllib.parse.parse_qsl(url_parts[4])) | |
1491 | query.update({param_name: param}) | |
1506 | query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) | |
1507 | query.append((param_name, param)) | |
1492 | 1508 | url_parts[4] = urllib.parse.urlencode(query) |
1493 | 1509 | return urllib.parse.urlunparse(url_parts) |
1494 | 1510 |
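The `add_query_param_to_url` fix above switches from `dict(parse_qsl(...))` to a list of pairs with `keep_blank_values=True`, so repeated keys and empty values in the existing query string survive instead of being silently dropped. A quick demonstration of the repaired behaviour:

```python
import urllib.parse

def add_query_param_to_url(url: str, param_name: str, param) -> str:
    url_parts = list(urllib.parse.urlparse(url))
    query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
    query.append((param_name, param))
    url_parts[4] = urllib.parse.urlencode(query)
    return urllib.parse.urlunparse(url_parts)

# The blank "flag" parameter and both "a" values survive; the old
# dict-based version would have dropped "flag" and collapsed "a".
print(add_query_param_to_url("https://example.com/cb?a=1&a=2&flag=", "token", "abc"))
# https://example.com/cb?a=1&a=2&flag=&token=abc
```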
74 | 74 | self._http_client = hs.get_proxied_http_client() |
75 | 75 | |
76 | 76 | # identifier for the external_ids table |
77 | self._auth_provider_id = "cas" | |
77 | self.idp_id = "cas" | |
78 | ||
79 | # user-facing name of this auth provider | |
80 | self.idp_name = "CAS" | |
81 | ||
82 | # we do not currently support icons for CAS auth, but this is required by | |
83 | # the SsoIdentityProvider protocol type. | |
84 | self.idp_icon = None | |
78 | 85 | |
79 | 86 | self._sso_handler = hs.get_sso_handler() |
87 | ||
88 | self._sso_handler.register_identity_provider(self) | |
80 | 89 | |
81 | 90 | def _build_service_param(self, args: Dict[str, str]) -> str: |
82 | 91 | """ |
104 | 113 | Args: |
105 | 114 | ticket: The CAS ticket from the client. |
106 | 115 | service_args: Additional arguments to include in the service URL. |
107 | Should be the same as those passed to `get_redirect_url`. | |
116 | Should be the same as those passed to `handle_redirect_request`. | |
108 | 117 | |
109 | 118 | Raises: |
110 | 119 | CasError: If there's an error parsing the CAS response. |
183 | 192 | |
184 | 193 | return CasResponse(user, attributes) |
185 | 194 | |
186 | def get_redirect_url(self, service_args: Dict[str, str]) -> str: | |
187 | """ | |
188 | Generates a URL for the CAS server where the client should be redirected. | |
189 | ||
190 | Args: | |
191 | service_args: Additional arguments to include in the final redirect URL. | |
195 | async def handle_redirect_request( | |
196 | self, | |
197 | request: SynapseRequest, | |
198 | client_redirect_url: Optional[bytes], | |
199 | ui_auth_session_id: Optional[str] = None, | |
200 | ) -> str: | |
201 | """Generates a URL for the CAS server where the client should be redirected. | |
202 | ||
203 | Args: | |
204 | request: the incoming HTTP request | |
205 | client_redirect_url: the URL that we should redirect the | |
206 | client to after login (or None for UI Auth). | |
207 | ui_auth_session_id: The session ID of the ongoing UI Auth (or | |
208 | None if this is a login). | |
192 | 209 | |
193 | 210 | Returns: |
194 | The URL to redirect the client to. | |
195 | """ | |
211 | URL to redirect to | |
212 | """ | |
213 | ||
214 | if ui_auth_session_id: | |
215 | service_args = {"session": ui_auth_session_id} | |
216 | else: | |
217 | assert client_redirect_url | |
218 | service_args = {"redirectUrl": client_redirect_url.decode("utf8")} | |
219 | ||
196 | 220 | args = urllib.parse.urlencode( |
197 | 221 | {"service": self._build_service_param(service_args)} |
198 | 222 | ) |
274 | 298 | # first check if we're doing a UIA |
275 | 299 | if session: |
276 | 300 | return await self._sso_handler.complete_sso_ui_auth_request( |
277 | self._auth_provider_id, cas_response.username, session, request, | |
301 | self.idp_id, cas_response.username, session, request, | |
278 | 302 | ) |
279 | 303 | |
280 | 304 | # otherwise, we're handling a login request. |
374 | 398 | return None |
375 | 399 | |
376 | 400 | await self._sso_handler.complete_sso_login_request( |
377 | self._auth_provider_id, | |
401 | self.idp_id, | |
378 | 402 | cas_response.username, |
379 | 403 | request, |
380 | 404 | client_redirect_url, |
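The reworked CAS `handle_redirect_request` distinguishes UI Auth from an ordinary login by which service argument it embeds: a `session` ID for UI Auth, the decoded client `redirectUrl` otherwise. A small sketch of that branch in isolation:

```python
from typing import Optional

def cas_service_args(
    client_redirect_url: Optional[bytes], ui_auth_session_id: Optional[str]
) -> dict:
    # UI Auth flows carry the auth session id; plain logins carry the
    # client's redirect URL instead.
    if ui_auth_session_id:
        return {"session": ui_auth_session_id}
    assert client_redirect_url is not None
    return {"redirectUrl": client_redirect_url.decode("utf8")}

assert cas_service_args(None, "abc123") == {"session": "abc123"}
assert cas_service_args(b"https://client/", None) == {"redirectUrl": "https://client/"}
```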
17 | 17 | |
18 | 18 | from synapse.api.errors import SynapseError |
19 | 19 | from synapse.metrics.background_process_metrics import run_as_background_process |
20 | from synapse.types import UserID, create_requester | |
20 | from synapse.types import Requester, UserID, create_requester | |
21 | 21 | |
22 | 22 | from ._base import BaseHandler |
23 | 23 | |
37 | 37 | self._device_handler = hs.get_device_handler() |
38 | 38 | self._room_member_handler = hs.get_room_member_handler() |
39 | 39 | self._identity_handler = hs.get_identity_handler() |
40 | self._profile_handler = hs.get_profile_handler() | |
40 | 41 | self.user_directory_handler = hs.get_user_directory_handler() |
41 | 42 | self._server_name = hs.hostname |
42 | 43 | |
51 | 52 | self._account_validity_enabled = hs.config.account_validity.enabled |
52 | 53 | |
53 | 54 | async def deactivate_account( |
54 | self, user_id: str, erase_data: bool, id_server: Optional[str] = None | |
55 | self, | |
56 | user_id: str, | |
57 | erase_data: bool, | |
58 | requester: Requester, | |
59 | id_server: Optional[str] = None, | |
60 | by_admin: bool = False, | |
55 | 61 | ) -> bool: |
56 | 62 | """Deactivate a user's account |
57 | 63 | |
58 | 64 | Args: |
59 | 65 | user_id: ID of user to be deactivated |
60 | 66 | erase_data: whether to GDPR-erase the user's data |
67 | requester: The user attempting to make this change. | |
61 | 68 | id_server: Use the given identity server when unbinding |
62 | 69 | any threepids. If None then will attempt to unbind using the |
63 | 70 | identity server specified when binding (if known). |
71 | by_admin: Whether this change was made by an administrator. | |
64 | 72 | |
65 | 73 | Returns: |
66 | 74 | True if identity server supports removing threepids, otherwise False. |
120 | 128 | |
121 | 129 | # Mark the user as erased, if they asked for that |
122 | 130 | if erase_data: |
131 | user = UserID.from_string(user_id) | |
132 | # Remove avatar URL from this user | |
133 | await self._profile_handler.set_avatar_url(user, requester, "", by_admin) | |
134 | # Remove displayname from this user | |
135 | await self._profile_handler.set_displayname(user, requester, "", by_admin) | |
136 | ||
123 | 137 | logger.info("Marking %s as erased", user_id) |
124 | 138 | await self.store.mark_user_erased(user_id) |
125 | 139 |
23 | 23 | set_tag, |
24 | 24 | start_active_span, |
25 | 25 | ) |
26 | from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet | |
26 | 27 | from synapse.types import JsonDict, UserID, get_domain_from_id |
27 | 28 | from synapse.util import json_encoder |
28 | 29 | from synapse.util.stringutils import random_string |
43 | 44 | self.store = hs.get_datastore() |
44 | 45 | self.notifier = hs.get_notifier() |
45 | 46 | self.is_mine = hs.is_mine |
46 | self.federation = hs.get_federation_sender() | |
47 | ||
48 | hs.get_federation_registry().register_edu_handler( | |
49 | "m.direct_to_device", self.on_direct_to_device_edu | |
50 | ) | |
51 | ||
52 | self._device_list_updater = hs.get_device_handler().device_list_updater | |
47 | ||
48 | # We only need to poke the federation sender explicitly if it's on the | |
49 | # same instance. Other federation sender instances will get notified by | |
50 | # `synapse.app.generic_worker.FederationSenderHandler` when it sees it | |
51 | # in the to-device replication stream. | |
52 | self.federation_sender = None | |
53 | if hs.should_send_federation(): | |
54 | self.federation_sender = hs.get_federation_sender() | |
55 | ||
56 | # If we can handle the to-device EDUs ourselves, we do so; otherwise we route | |
57 | # them to the appropriate worker. | |
58 | if hs.get_instance_name() in hs.config.worker.writers.to_device: | |
59 | hs.get_federation_registry().register_edu_handler( | |
60 | "m.direct_to_device", self.on_direct_to_device_edu | |
61 | ) | |
62 | else: | |
63 | hs.get_federation_registry().register_instances_for_edu( | |
64 | "m.direct_to_device", hs.config.worker.writers.to_device, | |
65 | ) | |
66 | ||
67 | # The handler to call when we think a user's device list might be out of | |
68 | # sync. We do all device list resyncing on the master instance, so if | |
69 | # we're on a worker we hit the device resync replication API. | |
70 | if hs.config.worker.worker_app is None: | |
71 | self._user_device_resync = ( | |
72 | hs.get_device_handler().device_list_updater.user_device_resync | |
73 | ) | |
74 | else: | |
75 | self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( | |
76 | hs | |
77 | ) | |
53 | 78 | |
54 | 79 | async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: |
55 | 80 | local_messages = {} |
137 | 162 | await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) |
138 | 163 | |
139 | 164 | # Immediately attempt a resync in the background |
140 | run_in_background( | |
141 | self._device_list_updater.user_device_resync, sender_user_id | |
142 | ) | |
165 | run_in_background(self._user_device_resync, user_id=sender_user_id) | |
143 | 166 | |
144 | 167 | async def send_device_message( |
145 | 168 | self, |
194 | 217 | ) |
195 | 218 | |
196 | 219 | log_kv({"remote_messages": remote_messages}) |
197 | for destination in remote_messages.keys(): | |
198 | # Enqueue a new federation transaction to send the new | |
199 | # device messages to each remote destination. | |
200 | self.federation.send_device_messages(destination) | |
220 | if self.federation_sender: | |
221 | for destination in remote_messages.keys(): | |
222 | # Enqueue a new federation transaction to send the new | |
223 | # device messages to each remote destination. | |
224 | self.federation_sender.send_device_messages(destination) |
474 | 474 | raise e.to_synapse_error() |
475 | 475 | except RequestTimedOutError: |
476 | 476 | raise SynapseError(500, "Timed out contacting identity server") |
477 | ||
478 | assert self.hs.config.public_baseurl | |
479 | 477 | |
480 | 478 | # we need to tell the client to send the token back to us, since it doesn't |
481 | 479 | # otherwise know where to send it, so add submit_url response parameter |
13 | 13 | # limitations under the License. |
14 | 14 | import inspect |
15 | 15 | import logging |
16 | from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar | |
16 | from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar | |
17 | 17 | from urllib.parse import urlencode |
18 | 18 | |
19 | 19 | import attr |
34 | 34 | from twisted.web.client import readBody |
35 | 35 | |
36 | 36 | from synapse.config import ConfigError |
37 | from synapse.handlers._base import BaseHandler | |
37 | from synapse.config.oidc_config import OidcProviderConfig | |
38 | 38 | from synapse.handlers.sso import MappingException, UserAttributes |
39 | 39 | from synapse.http.site import SynapseRequest |
40 | 40 | from synapse.logging.context import make_deferred_yieldable |
70 | 70 | JWKS = TypedDict("JWKS", {"keys": List[JWK]}) |
71 | 71 | |
72 | 72 | |
73 | class OidcHandler: | |
74 | """Handles requests related to the OpenID Connect login flow. | |
75 | """ | |
76 | ||
77 | def __init__(self, hs: "HomeServer"): | |
78 | self._sso_handler = hs.get_sso_handler() | |
79 | ||
80 | provider_confs = hs.config.oidc.oidc_providers | |
81 | # we should not have been instantiated if there is no configured provider. | |
82 | assert provider_confs | |
83 | ||
84 | self._token_generator = OidcSessionTokenGenerator(hs) | |
85 | self._providers = { | |
86 | p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs | |
87 | } # type: Dict[str, OidcProvider] | |
88 | ||
89 | async def load_metadata(self) -> None: | |
90 | """Validate the config and load the metadata from the remote endpoint. | |
91 | ||
92 | Called at startup to ensure we have everything we need. | |
93 | """ | |
94 | for idp_id, p in self._providers.items(): | |
95 | try: | |
96 | await p.load_metadata() | |
97 | await p.load_jwks() | |
98 | except Exception as e: | |
99 | raise Exception( | |
100 | "Error while initialising OIDC provider %r" % (idp_id,) | |
101 | ) from e | |
102 | ||
103 | async def handle_oidc_callback(self, request: SynapseRequest) -> None: | |
104 | """Handle an incoming request to /_synapse/oidc/callback | |
105 | ||
106 | Since we might want to display OIDC-related errors in a user-friendly | |
107 | way, we don't raise SynapseError from here. Instead, we call | |
108 | ``self._sso_handler.render_error`` which displays an HTML page for the error. | |
109 | ||
110 | Most of the OpenID Connect logic happens here: | |
111 | ||
112 | - first, we check if there was any error returned by the provider and | |
113 | display it | |
114 | - then we fetch the session cookie, decode and verify it | |
115 | - the ``state`` query parameter should match with the one stored in the | |
116 | session cookie | |
117 | ||
118 | Once we know the session is legit, we then delegate to the OIDC Provider | |
119 | implementation, which will exchange the code with the provider and complete the | |
120 | login/authentication. | |
121 | ||
122 | Args: | |
123 | request: the incoming request from the browser. | |
124 | """ | |
125 | ||
126 | # The provider might redirect with an error. | |
127 | # In that case, just display it as-is. | |
128 | if b"error" in request.args: | |
129 | # error response from the auth server. see: | |
130 | # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 | |
131 | # https://openid.net/specs/openid-connect-core-1_0.html#AuthError | |
132 | error = request.args[b"error"][0].decode() | |
133 | description = request.args.get(b"error_description", [b""])[0].decode() | |
134 | ||
135 | # Most of the errors returned by the provider could be due to | |
136 | # either the provider misbehaving or Synapse being misconfigured. | |
137 | # The only exception to that is "access_denied", where the user | |
138 | # probably cancelled the login flow. In other cases, log those errors. | |
139 | if error != "access_denied": | |
140 | logger.error("Error from the OIDC provider: %s %s", error, description) | |
141 | ||
142 | self._sso_handler.render_error(request, error, description) | |
143 | return | |
144 | ||
145 | # otherwise, it is presumably a successful response. see: | |
146 | # https://tools.ietf.org/html/rfc6749#section-4.1.2 | |
147 | ||
148 | # Fetch the session cookie | |
149 | session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] | |
150 | if session is None: | |
151 | logger.info("No session cookie found") | |
152 | self._sso_handler.render_error( | |
153 | request, "missing_session", "No session cookie found" | |
154 | ) | |
155 | return | |
156 | ||
157 | # Remove the cookie. There is a good chance that if the callback failed | |
158 | # once, it will fail next time and the code will already be exchanged. | |
159 | # Removing it early avoids spamming the provider with token requests. | |
160 | request.addCookie( | |
161 | SESSION_COOKIE_NAME, | |
162 | b"", | |
163 | path="/_synapse/oidc", | |
164 | expires="Thu, Jan 01 1970 00:00:00 UTC", | |
165 | httpOnly=True, | |
166 | sameSite="lax", | |
167 | ) | |
168 | ||
169 | # Check for the state query parameter | |
170 | if b"state" not in request.args: | |
171 | logger.info("State parameter is missing") | |
172 | self._sso_handler.render_error( | |
173 | request, "invalid_request", "State parameter is missing" | |
174 | ) | |
175 | return | |
176 | ||
177 | state = request.args[b"state"][0].decode() | |
178 | ||
179 | # Deserialize the session token and verify it. | |
180 | try: | |
181 | session_data = self._token_generator.verify_oidc_session_token( | |
182 | session, state | |
183 | ) | |
184 | except (MacaroonDeserializationException, ValueError) as e: | |
185 | logger.exception("Invalid session") | |
186 | self._sso_handler.render_error(request, "invalid_session", str(e)) | |
187 | return | |
188 | except MacaroonInvalidSignatureException as e: | |
189 | logger.exception("Could not verify session") | |
190 | self._sso_handler.render_error(request, "mismatching_session", str(e)) | |
191 | return | |
192 | ||
193 | oidc_provider = self._providers.get(session_data.idp_id) | |
194 | if not oidc_provider: | |
195 | logger.error("OIDC session uses unknown IdP %r", session_data.idp_id) | |
196 | self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP") | |
197 | return | |
198 | ||
199 | if b"code" not in request.args: | |
200 | logger.info("Code parameter is missing") | |
201 | self._sso_handler.render_error( | |
202 | request, "invalid_request", "Code parameter is missing" | |
203 | ) | |
204 | return | |
205 | ||
206 | code = request.args[b"code"][0].decode() | |
207 | ||
208 | await oidc_provider.handle_oidc_callback(request, session_data, code) | |
209 | ||
210 | ||
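The ``state`` comparison above is the CSRF defence at the heart of this
callback: the value in the query string must match the one bound to the
browser via the session cookie before the authorization code is trusted.
A minimal, self-contained sketch of that round-trip (illustrative names
only, not Synapse APIs):

    import secrets

    def begin_login() -> str:
        # issue a random state and bind it to the browser; Synapse stores
        # it inside a signed macaroon cookie, a raw value is enough here
        return secrets.token_urlsafe(32)

    def check_callback(query_state: str, cookie_state: str) -> bool:
        # compare in constant time before trusting the authorization code
        return secrets.compare_digest(query_state, cookie_state)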
73 | 211 | class OidcError(Exception): |
74 | 212 | """Used to catch errors when calling the token_endpoint |
75 | 213 | """ |
84 | 222 | return self.error |
85 | 223 | |
86 | 224 | |
87 | class OidcHandler(BaseHandler): | |
88 | """Handles requests related to the OpenID Connect login flow. | |
225 | class OidcProvider: | |
226 | """Wraps the config for a single OIDC IdentityProvider | |
227 | ||
228 | Provides methods for handling redirect requests and callbacks via that particular | |
229 | IdP. | |
89 | 230 | """ |
90 | 231 | |
91 | def __init__(self, hs: "HomeServer"): | |
92 | super().__init__(hs) | |
232 | def __init__( | |
233 | self, | |
234 | hs: "HomeServer", | |
235 | token_generator: "OidcSessionTokenGenerator", | |
236 | provider: OidcProviderConfig, | |
237 | ): | |
238 | self._store = hs.get_datastore() | |
239 | ||
240 | self._token_generator = token_generator | |
241 | ||
93 | 242 | self._callback_url = hs.config.oidc_callback_url # type: str |
94 | self._scopes = hs.config.oidc_scopes # type: List[str] | |
95 | self._user_profile_method = hs.config.oidc_user_profile_method # type: str | |
243 | ||
244 | self._scopes = provider.scopes | |
245 | self._user_profile_method = provider.user_profile_method | |
96 | 246 | self._client_auth = ClientAuth( |
97 | hs.config.oidc_client_id, | |
98 | hs.config.oidc_client_secret, | |
99 | hs.config.oidc_client_auth_method, | |
247 | provider.client_id, provider.client_secret, provider.client_auth_method, | |
100 | 248 | ) # type: ClientAuth |
101 | self._client_auth_method = hs.config.oidc_client_auth_method # type: str | |
249 | self._client_auth_method = provider.client_auth_method | |
102 | 250 | self._provider_metadata = OpenIDProviderMetadata( |
103 | issuer=hs.config.oidc_issuer, | |
104 | authorization_endpoint=hs.config.oidc_authorization_endpoint, | |
105 | token_endpoint=hs.config.oidc_token_endpoint, | |
106 | userinfo_endpoint=hs.config.oidc_userinfo_endpoint, | |
107 | jwks_uri=hs.config.oidc_jwks_uri, | |
251 | issuer=provider.issuer, | |
252 | authorization_endpoint=provider.authorization_endpoint, | |
253 | token_endpoint=provider.token_endpoint, | |
254 | userinfo_endpoint=provider.userinfo_endpoint, | |
255 | jwks_uri=provider.jwks_uri, | |
108 | 256 | ) # type: OpenIDProviderMetadata |
109 | self._provider_needs_discovery = hs.config.oidc_discover # type: bool | |
110 | self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( | |
111 | hs.config.oidc_user_mapping_provider_config | |
112 | ) # type: OidcMappingProvider | |
113 | self._skip_verification = hs.config.oidc_skip_verification # type: bool | |
114 | self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool | |
257 | self._provider_needs_discovery = provider.discover | |
258 | self._user_mapping_provider = provider.user_mapping_provider_class( | |
259 | provider.user_mapping_provider_config | |
260 | ) | |
261 | self._skip_verification = provider.skip_verification | |
262 | self._allow_existing_users = provider.allow_existing_users | |
115 | 263 | |
116 | 264 | self._http_client = hs.get_proxied_http_client() |
117 | 265 | self._server_name = hs.config.server_name # type: str |
118 | self._macaroon_secret_key = hs.config.macaroon_secret_key | |
119 | 266 | |
120 | 267 | # identifier for the external_ids table |
121 | self._auth_provider_id = "oidc" | |
268 | self.idp_id = provider.idp_id | |
269 | ||
270 | # user-facing name of this auth provider | |
271 | self.idp_name = provider.idp_name | |
272 | ||
273 | # MXC URI for icon for this auth provider | |
274 | self.idp_icon = provider.idp_icon | |
122 | 275 | |
123 | 276 | self._sso_handler = hs.get_sso_handler() |
277 | ||
278 | self._sso_handler.register_identity_provider(self) | |
124 | 279 | |
125 | 280 | def _validate_metadata(self): |
126 | 281 | """Verifies the provider metadata. |
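Note how the constructor above ends with ``register_identity_provider(self)``:
each ``OidcProvider`` announces itself to the shared SSO handler, which is
what lets ``handle_oidc_callback`` look providers up by ``session_data.idp_id``
later. A standalone toy version of the same self-registration pattern
(hypothetical names):

    from typing import Dict

    class Provider:
        def __init__(self, registry: Dict[str, "Provider"], idp_id: str):
            self.idp_id = idp_id
            registry[idp_id] = self  # register on construction, as above

    registry: Dict[str, Provider] = {}
    Provider(registry, "oidc-github")
    Provider(registry, "oidc-gitlab")
    assert set(registry) == {"oidc-github", "oidc-gitlab"}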
474 | 629 | async def handle_redirect_request( |
475 | 630 | self, |
476 | 631 | request: SynapseRequest, |
477 | client_redirect_url: bytes, | |
632 | client_redirect_url: Optional[bytes], | |
478 | 633 | ui_auth_session_id: Optional[str] = None, |
479 | 634 | ) -> str: |
480 | 635 | """Handle an incoming request to /login/sso/redirect |
498 | 653 | request: the incoming request from the browser. |
499 | 654 | We'll respond to it with a redirect and a cookie. |
500 | 655 | client_redirect_url: the URL that we should redirect the client to |
501 | when everything is done | |
656 | when everything is done (or None for UI Auth) | |
502 | 657 | ui_auth_session_id: The session ID of the ongoing UI Auth (or |
503 | 658 | None if this is a login). |
504 | 659 | |
510 | 665 | state = generate_token() |
511 | 666 | nonce = generate_token() |
512 | 667 | |
513 | cookie = self._generate_oidc_session_token( | |
668 | if not client_redirect_url: | |
669 | client_redirect_url = b"" | |
670 | ||
671 | cookie = self._token_generator.generate_oidc_session_token( | |
514 | 672 | state=state, |
515 | nonce=nonce, | |
516 | client_redirect_url=client_redirect_url.decode(), | |
517 | ui_auth_session_id=ui_auth_session_id, | |
673 | session_data=OidcSessionData( | |
674 | idp_id=self.idp_id, | |
675 | nonce=nonce, | |
676 | client_redirect_url=client_redirect_url.decode(), | |
677 | ui_auth_session_id=ui_auth_session_id, | |
678 | ), | |
518 | 679 | ) |
519 | 680 | request.addCookie( |
520 | 681 | SESSION_COOKIE_NAME, |
537 | 698 | nonce=nonce, |
538 | 699 | ) |
539 | 700 | |
540 | async def handle_oidc_callback(self, request: SynapseRequest) -> None: | |
701 | async def handle_oidc_callback( | |
702 | self, request: SynapseRequest, session_data: "OidcSessionData", code: str | |
703 | ) -> None: | |
541 | 704 | """Handle an incoming request to /_synapse/oidc/callback |
542 | 705 | |
543 | Since we might want to display OIDC-related errors in a user-friendly | |
544 | way, we don't raise SynapseError from here. Instead, we call | |
545 | ``self._sso_handler.render_error`` which displays an HTML page for the error. | |
546 | ||
547 | Most of the OpenID Connect logic happens here: | |
548 | ||
549 | - first, we check if there was any error returned by the provider and | |
550 | display it | |
551 | - then we fetch the session cookie, decode and verify it | |
552 | - the ``state`` query parameter should match with the one stored in the | |
553 | session cookie | |
554 | - once we know this session is legit, exchange the code with the | |
555 | provider using the ``token_endpoint`` (see ``_exchange_code``) | |
706 | By this time we have already validated the session on the Synapse side, and | |
707 | now need to do the provider-specific operations. This includes: | |
708 | ||
709 | - exchange the code with the provider using the ``token_endpoint`` (see | |
710 | ``_exchange_code``) | |
556 | 711 | - once we have the token, use it to either extract the UserInfo from |
557 | 712 | the ``id_token`` (``_parse_id_token``), or use the ``access_token`` |
558 | 713 | to fetch UserInfo from the ``userinfo_endpoint`` |
562 | 717 | |
563 | 718 | Args: |
564 | 719 | request: the incoming request from the browser. |
565 | """ | |
566 | ||
567 | # The provider might redirect with an error. | |
568 | # In that case, just display it as-is. | |
569 | if b"error" in request.args: | |
570 | # error response from the auth server. see: | |
571 | # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 | |
572 | # https://openid.net/specs/openid-connect-core-1_0.html#AuthError | |
573 | error = request.args[b"error"][0].decode() | |
574 | description = request.args.get(b"error_description", [b""])[0].decode() | |
575 | ||
576 | # Most of the errors returned by the provider could be due to | |
577 | # either the provider misbehaving or Synapse being misconfigured. | |
578 | # The only exception to that is "access_denied", where the user | |
579 | # probably cancelled the login flow. In other cases, log those errors. | |
580 | if error != "access_denied": | |
581 | logger.error("Error from the OIDC provider: %s %s", error, description) | |
582 | ||
583 | self._sso_handler.render_error(request, error, description) | |
584 | return | |
585 | ||
586 | # otherwise, it is presumably a successful response. see: | |
587 | # https://tools.ietf.org/html/rfc6749#section-4.1.2 | |
588 | ||
589 | # Fetch the session cookie | |
590 | session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] | |
591 | if session is None: | |
592 | logger.info("No session cookie found") | |
593 | self._sso_handler.render_error( | |
594 | request, "missing_session", "No session cookie found" | |
595 | ) | |
596 | return | |
597 | ||
598 | # Remove the cookie. There is a good chance that if the callback failed | |
599 | # once, it will fail next time and the code will already be exchanged. | |
600 | # Removing it early avoids spamming the provider with token requests. | |
601 | request.addCookie( | |
602 | SESSION_COOKIE_NAME, | |
603 | b"", | |
604 | path="/_synapse/oidc", | |
605 | expires="Thu, Jan 01 1970 00:00:00 UTC", | |
606 | httpOnly=True, | |
607 | sameSite="lax", | |
608 | ) | |
609 | ||
610 | # Check for the state query parameter | |
611 | if b"state" not in request.args: | |
612 | logger.info("State parameter is missing") | |
613 | self._sso_handler.render_error( | |
614 | request, "invalid_request", "State parameter is missing" | |
615 | ) | |
616 | return | |
617 | ||
618 | state = request.args[b"state"][0].decode() | |
619 | ||
620 | # Deserialize the session token and verify it. | |
720 | session_data: the session data, extracted from our cookie | |
721 | code: The authorization code we got from the callback. | |
722 | """ | |
723 | # Exchange the code with the provider | |
621 | 724 | try: |
622 | ( | |
623 | nonce, | |
624 | client_redirect_url, | |
625 | ui_auth_session_id, | |
626 | ) = self._verify_oidc_session_token(session, state) | |
627 | except MacaroonDeserializationException as e: | |
628 | logger.exception("Invalid session") | |
629 | self._sso_handler.render_error(request, "invalid_session", str(e)) | |
630 | return | |
631 | except MacaroonInvalidSignatureException as e: | |
632 | logger.exception("Could not verify session") | |
633 | self._sso_handler.render_error(request, "mismatching_session", str(e)) | |
634 | return | |
635 | ||
636 | # Exchange the code with the provider | |
637 | if b"code" not in request.args: | |
638 | logger.info("Code parameter is missing") | |
639 | self._sso_handler.render_error( | |
640 | request, "invalid_request", "Code parameter is missing" | |
641 | ) | |
642 | return | |
643 | ||
644 | logger.debug("Exchanging code") | |
645 | code = request.args[b"code"][0].decode() | |
646 | try: | |
725 | logger.debug("Exchanging code") | |
647 | 726 | token = await self._exchange_code(code) |
648 | 727 | except OidcError as e: |
649 | 728 | logger.exception("Could not exchange code") |
665 | 744 | else: |
666 | 745 | logger.debug("Extracting userinfo from id_token") |
667 | 746 | try: |
668 | userinfo = await self._parse_id_token(token, nonce=nonce) | |
747 | userinfo = await self._parse_id_token(token, nonce=session_data.nonce) | |
669 | 748 | except Exception as e: |
670 | 749 | logger.exception("Invalid id_token") |
671 | 750 | self._sso_handler.render_error(request, "invalid_token", str(e)) |
672 | 751 | return |
673 | 752 | |
674 | 753 | # first check if we're doing a UIA |
675 | if ui_auth_session_id: | |
754 | if session_data.ui_auth_session_id: | |
676 | 755 | try: |
677 | 756 | remote_user_id = self._remote_id_from_userinfo(userinfo) |
678 | 757 | except Exception as e: |
681 | 760 | return |
682 | 761 | |
683 | 762 | return await self._sso_handler.complete_sso_ui_auth_request( |
684 | self._auth_provider_id, remote_user_id, ui_auth_session_id, request | |
763 | self.idp_id, remote_user_id, session_data.ui_auth_session_id, request | |
685 | 764 | ) |
686 | 765 | |
687 | 766 | # otherwise, it's a login |
689 | 768 | # Call the mapper to register/login the user |
690 | 769 | try: |
691 | 770 | await self._complete_oidc_login( |
692 | userinfo, token, request, client_redirect_url | |
771 | userinfo, token, request, session_data.client_redirect_url | |
693 | 772 | ) |
694 | 773 | except MappingException as e: |
695 | 774 | logger.exception("Could not map user") |
696 | 775 | self._sso_handler.render_error(request, "mapping_error", str(e)) |
697 | ||
698 | def _generate_oidc_session_token( | |
699 | self, | |
700 | state: str, | |
701 | nonce: str, | |
702 | client_redirect_url: str, | |
703 | ui_auth_session_id: Optional[str], | |
704 | duration_in_ms: int = (60 * 60 * 1000), | |
705 | ) -> str: | |
706 | """Generates a signed token storing data about an OIDC session. | |
707 | ||
708 | When Synapse initiates an authorization flow, it creates a random state | |
709 | and a random nonce. Those parameters are given to the provider and | |
710 | should be verified when the client comes back from the provider. | |
711 | It is also used to store the client_redirect_url, which is used to | |
712 | complete the SSO login flow. | |
713 | ||
714 | Args: | |
715 | state: The ``state`` parameter passed to the OIDC provider. | |
716 | nonce: The ``nonce`` parameter passed to the OIDC provider. | |
717 | client_redirect_url: The URL the client gave when it initiated the | |
718 | flow. | |
719 | ui_auth_session_id: The session ID of the ongoing UI Auth (or | |
720 | None if this is a login). | |
721 | duration_in_ms: An optional duration for the token in milliseconds. | |
722 | Defaults to an hour. | |
723 | ||
724 | Returns: | |
725 | A signed macaroon token with the session information. | |
726 | """ | |
727 | macaroon = pymacaroons.Macaroon( | |
728 | location=self._server_name, identifier="key", key=self._macaroon_secret_key, | |
729 | ) | |
730 | macaroon.add_first_party_caveat("gen = 1") | |
731 | macaroon.add_first_party_caveat("type = session") | |
732 | macaroon.add_first_party_caveat("state = %s" % (state,)) | |
733 | macaroon.add_first_party_caveat("nonce = %s" % (nonce,)) | |
734 | macaroon.add_first_party_caveat( | |
735 | "client_redirect_url = %s" % (client_redirect_url,) | |
736 | ) | |
737 | if ui_auth_session_id: | |
738 | macaroon.add_first_party_caveat( | |
739 | "ui_auth_session_id = %s" % (ui_auth_session_id,) | |
740 | ) | |
741 | now = self.clock.time_msec() | |
742 | expiry = now + duration_in_ms | |
743 | macaroon.add_first_party_caveat("time < %d" % (expiry,)) | |
744 | ||
745 | return macaroon.serialize() | |
746 | ||
747 | def _verify_oidc_session_token( | |
748 | self, session: bytes, state: str | |
749 | ) -> Tuple[str, str, Optional[str]]: | |
750 | """Verifies and extract an OIDC session token. | |
751 | ||
752 | This verifies that a given session token was issued by this homeserver | |
753 | and extract the nonce and client_redirect_url caveats. | |
754 | ||
755 | Args: | |
756 | session: The session token to verify | |
757 | state: The state the OIDC provider gave back | |
758 | ||
759 | Returns: | |
760 | The nonce, client_redirect_url, and ui_auth_session_id for this session | |
761 | """ | |
762 | macaroon = pymacaroons.Macaroon.deserialize(session) | |
763 | ||
764 | v = pymacaroons.Verifier() | |
765 | v.satisfy_exact("gen = 1") | |
766 | v.satisfy_exact("type = session") | |
767 | v.satisfy_exact("state = %s" % (state,)) | |
768 | v.satisfy_general(lambda c: c.startswith("nonce = ")) | |
769 | v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) | |
770 | # Sometimes there's a UI auth session ID; it seems to be OK to attempt | |
771 | # to always satisfy this. | |
772 | v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) | |
773 | v.satisfy_general(self._verify_expiry) | |
774 | ||
775 | v.verify(macaroon, self._macaroon_secret_key) | |
776 | ||
777 | # Extract the `nonce`, `client_redirect_url`, and maybe the | |
778 | # `ui_auth_session_id` from the token. | |
779 | nonce = self._get_value_from_macaroon(macaroon, "nonce") | |
780 | client_redirect_url = self._get_value_from_macaroon( | |
781 | macaroon, "client_redirect_url" | |
782 | ) | |
783 | try: | |
784 | ui_auth_session_id = self._get_value_from_macaroon( | |
785 | macaroon, "ui_auth_session_id" | |
786 | ) # type: Optional[str] | |
787 | except ValueError: | |
788 | ui_auth_session_id = None | |
789 | ||
790 | return nonce, client_redirect_url, ui_auth_session_id | |
791 | ||
792 | def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: | |
793 | """Extracts a caveat value from a macaroon token. | |
794 | ||
795 | Args: | |
796 | macaroon: the token | |
797 | key: the key of the caveat to extract | |
798 | ||
799 | Returns: | |
800 | The extracted value | |
801 | ||
802 | Raises: | |
803 | Exception: if the caveat was not in the macaroon | |
804 | """ | |
805 | prefix = key + " = " | |
806 | for caveat in macaroon.caveats: | |
807 | if caveat.caveat_id.startswith(prefix): | |
808 | return caveat.caveat_id[len(prefix) :] | |
809 | raise ValueError("No %s caveat in macaroon" % (key,)) | |
810 | ||
811 | def _verify_expiry(self, caveat: str) -> bool: | |
812 | prefix = "time < " | |
813 | if not caveat.startswith(prefix): | |
814 | return False | |
815 | expiry = int(caveat[len(prefix) :]) | |
816 | now = self.clock.time_msec() | |
817 | return now < expiry | |
818 | 776 | |
819 | 777 | async def _complete_oidc_login( |
820 | 778 | self, |
892 | 850 | # and attempt to match it. |
893 | 851 | attributes = await oidc_response_to_user_attributes(failures=0) |
894 | 852 | |
895 | user_id = UserID(attributes.localpart, self.server_name).to_string() | |
896 | users = await self.store.get_users_by_id_case_insensitive(user_id) | |
853 | user_id = UserID(attributes.localpart, self._server_name).to_string() | |
854 | users = await self._store.get_users_by_id_case_insensitive(user_id) | |
897 | 855 | if users: |
898 | 856 | # If an existing matrix ID is returned, then use it. |
899 | 857 | if len(users) == 1: |
922 | 880 | extra_attributes = await get_extra_attributes(userinfo, token) |
923 | 881 | |
924 | 882 | await self._sso_handler.complete_sso_login_request( |
925 | self._auth_provider_id, | |
883 | self.idp_id, | |
926 | 884 | remote_user_id, |
927 | 885 | request, |
928 | 886 | client_redirect_url, |
943 | 901 | # Some OIDC providers use integer IDs, but Synapse expects external IDs |
944 | 902 | # to be strings. |
945 | 903 | return str(remote_user_id) |
904 | ||
905 | ||
906 | class OidcSessionTokenGenerator: | |
907 | """Methods for generating and checking OIDC Session cookies.""" | |
908 | ||
909 | def __init__(self, hs: "HomeServer"): | |
910 | self._clock = hs.get_clock() | |
911 | self._server_name = hs.hostname | |
912 | self._macaroon_secret_key = hs.config.key.macaroon_secret_key | |
913 | ||
914 | def generate_oidc_session_token( | |
915 | self, | |
916 | state: str, | |
917 | session_data: "OidcSessionData", | |
918 | duration_in_ms: int = (60 * 60 * 1000), | |
919 | ) -> str: | |
920 | """Generates a signed token storing data about an OIDC session. | |
921 | ||
922 | When Synapse initiates an authorization flow, it creates a random state | |
923 | and a random nonce. Those parameters are given to the provider and | |
924 | should be verified when the client comes back from the provider. | |
925 | The session token also stores the client_redirect_url, which is used to | |
926 | complete the SSO login flow. | |
927 | ||
928 | Args: | |
929 | state: The ``state`` parameter passed to the OIDC provider. | |
930 | session_data: data to include in the session token. | |
931 | duration_in_ms: An optional duration for the token in milliseconds. | |
932 | Defaults to an hour. | |
933 | ||
934 | Returns: | |
935 | A signed macaroon token with the session information. | |
936 | """ | |
937 | macaroon = pymacaroons.Macaroon( | |
938 | location=self._server_name, identifier="key", key=self._macaroon_secret_key, | |
939 | ) | |
940 | macaroon.add_first_party_caveat("gen = 1") | |
941 | macaroon.add_first_party_caveat("type = session") | |
942 | macaroon.add_first_party_caveat("state = %s" % (state,)) | |
943 | macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,)) | |
944 | macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,)) | |
945 | macaroon.add_first_party_caveat( | |
946 | "client_redirect_url = %s" % (session_data.client_redirect_url,) | |
947 | ) | |
948 | if session_data.ui_auth_session_id: | |
949 | macaroon.add_first_party_caveat( | |
950 | "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) | |
951 | ) | |
952 | now = self._clock.time_msec() | |
953 | expiry = now + duration_in_ms | |
954 | macaroon.add_first_party_caveat("time < %d" % (expiry,)) | |
955 | ||
956 | return macaroon.serialize() | |
957 | ||
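The session token is a plain first-party macaroon: each piece of session data
becomes a caveat, and verification later checks both the caveats and the HMAC
signature. A small standalone round-trip using the same ``pymacaroons`` calls
(the secret and caveat values are illustrative):

    import pymacaroons

    secret = "macaroon-secret-key"
    m = pymacaroons.Macaroon(location="example.com", identifier="key", key=secret)
    m.add_first_party_caveat("type = session")
    m.add_first_party_caveat("state = abc123")
    token = m.serialize()

    # verification succeeds only if every caveat is satisfied and the
    # signature matches the original secret
    v = pymacaroons.Verifier()
    v.satisfy_exact("type = session")
    v.satisfy_exact("state = abc123")
    v.verify(pymacaroons.Macaroon.deserialize(token), secret)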
958 | def verify_oidc_session_token( | |
959 | self, session: bytes, state: str | |
960 | ) -> "OidcSessionData": | |
961 | """Verifies and extract an OIDC session token. | |
962 | ||
963 | This verifies that a given session token was issued by this homeserver | |
964 | and extract the nonce and client_redirect_url caveats. | |
965 | ||
966 | Args: | |
967 | session: The session token to verify | |
968 | state: The state the OIDC provider gave back | |
969 | ||
970 | Returns: | |
971 | The data extracted from the session cookie | |
972 | ||
973 | Raises: | |
974 | ValueError if an expected caveat is missing from the macaroon. | |
975 | """ | |
976 | macaroon = pymacaroons.Macaroon.deserialize(session) | |
977 | ||
978 | v = pymacaroons.Verifier() | |
979 | v.satisfy_exact("gen = 1") | |
980 | v.satisfy_exact("type = session") | |
981 | v.satisfy_exact("state = %s" % (state,)) | |
982 | v.satisfy_general(lambda c: c.startswith("nonce = ")) | |
983 | v.satisfy_general(lambda c: c.startswith("idp_id = ")) | |
984 | v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) | |
985 | # Sometimes there's a UI auth session ID; it seems to be OK to attempt | |
986 | # to always satisfy this. | |
987 | v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) | |
988 | v.satisfy_general(self._verify_expiry) | |
989 | ||
990 | v.verify(macaroon, self._macaroon_secret_key) | |
991 | ||
992 | # Extract the session data from the token. | |
993 | nonce = self._get_value_from_macaroon(macaroon, "nonce") | |
994 | idp_id = self._get_value_from_macaroon(macaroon, "idp_id") | |
995 | client_redirect_url = self._get_value_from_macaroon( | |
996 | macaroon, "client_redirect_url" | |
997 | ) | |
998 | try: | |
999 | ui_auth_session_id = self._get_value_from_macaroon( | |
1000 | macaroon, "ui_auth_session_id" | |
1001 | ) # type: Optional[str] | |
1002 | except ValueError: | |
1003 | ui_auth_session_id = None | |
1004 | ||
1005 | return OidcSessionData( | |
1006 | nonce=nonce, | |
1007 | idp_id=idp_id, | |
1008 | client_redirect_url=client_redirect_url, | |
1009 | ui_auth_session_id=ui_auth_session_id, | |
1010 | ) | |
1011 | ||
1012 | def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: | |
1013 | """Extracts a caveat value from a macaroon token. | |
1014 | ||
1015 | Args: | |
1016 | macaroon: the token | |
1017 | key: the key of the caveat to extract | |
1018 | ||
1019 | Returns: | |
1020 | The extracted value | |
1021 | ||
1022 | Raises: | |
1023 | ValueError: if the caveat was not in the macaroon | |
1024 | """ | |
1025 | prefix = key + " = " | |
1026 | for caveat in macaroon.caveats: | |
1027 | if caveat.caveat_id.startswith(prefix): | |
1028 | return caveat.caveat_id[len(prefix) :] | |
1029 | raise ValueError("No %s caveat in macaroon" % (key,)) | |
1030 | ||
1031 | def _verify_expiry(self, caveat: str) -> bool: | |
1032 | prefix = "time < " | |
1033 | if not caveat.startswith(prefix): | |
1034 | return False | |
1035 | expiry = int(caveat[len(prefix) :]) | |
1036 | now = self._clock.time_msec() | |
1037 | return now < expiry | |
1038 | ||
1039 | ||
1040 | @attr.s(frozen=True, slots=True) | |
1041 | class OidcSessionData: | |
1042 | """The attributes which are stored in a OIDC session cookie""" | |
1043 | ||
1044 | # the Identity Provider being used | |
1045 | idp_id = attr.ib(type=str) | |
1046 | ||
1047 | # The `nonce` parameter passed to the OIDC provider. | |
1048 | nonce = attr.ib(type=str) | |
1049 | ||
1050 | # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) | |
1051 | client_redirect_url = attr.ib(type=str) | |
1052 | ||
1053 | # The session ID of the ongoing UI Auth (None if this is a login) | |
1054 | ui_auth_session_id = attr.ib(type=Optional[str], default=None) | |
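``frozen=True`` makes the session data immutable once built and ``slots=True``
keeps instances small, both a good fit for a value that round-trips through a
cookie. A quick illustration with a toy class (not the one above):

    import attr

    @attr.s(frozen=True, slots=True)
    class Session:
        idp_id = attr.ib(type=str)
        nonce = attr.ib(type=str)

    s = Session(idp_id="oidc-github", nonce="n0nce")
    # s.nonce = "other"  # would raise attr.exceptions.FrozenInstanceError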
946 | 1055 | |
947 | 1056 | |
948 | 1057 | UserAttributeDict = TypedDict( |
155 | 155 | except HttpResponseException as e: |
156 | 156 | raise e.to_synapse_error() |
157 | 157 | |
158 | return result["displayname"] | |
158 | return result.get("displayname") | |
159 | 159 | |
160 | 160 | async def set_displayname( |
161 | 161 | self, |
245 | 245 | except HttpResponseException as e: |
246 | 246 | raise e.to_synapse_error() |
247 | 247 | |
248 | return result["avatar_url"] | |
248 | return result.get("avatar_url") | |
249 | 249 | |
250 | 250 | async def set_avatar_url( |
251 | 251 | self, |
285 | 285 | 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) |
286 | 286 | ) |
287 | 287 | |
288 | avatar_url_to_set = new_avatar_url # type: Optional[str] | |
289 | if new_avatar_url == "": | |
290 | avatar_url_to_set = None | |
291 | ||
288 | 292 | # Same like set_displayname |
289 | 293 | if by_admin: |
290 | 294 | requester = create_requester( |
291 | 295 | target_user, authenticated_entity=requester.authenticated_entity |
292 | 296 | ) |
293 | 297 | |
294 | await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) | |
298 | await self.store.set_profile_avatar_url( | |
299 | target_user.localpart, avatar_url_to_set | |
300 | ) | |
295 | 301 | |
296 | 302 | if self.hs.config.user_directory_search_all_users: |
297 | 303 | profile = await self.store.get_profileinfo(target_user.localpart) |
30 | 30 | super().__init__(hs) |
31 | 31 | self.server_name = hs.config.server_name |
32 | 32 | self.store = hs.get_datastore() |
33 | self.account_data_handler = hs.get_account_data_handler() | |
33 | 34 | self.read_marker_linearizer = Linearizer(name="read_marker") |
34 | self.notifier = hs.get_notifier() | |
35 | 35 | |
36 | 36 | async def received_client_read_marker( |
37 | 37 | self, room_id: str, user_id: str, event_id: str |
58 | 58 | |
59 | 59 | if should_update: |
60 | 60 | content = {"event_id": event_id} |
61 | max_id = await self.store.add_account_data_to_room( | |
61 | await self.account_data_handler.add_account_data_to_room( | |
62 | 62 | user_id, room_id, "m.fully_read", content |
63 | 63 | ) |
64 | self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) |
31 | 31 | self.server_name = hs.config.server_name |
32 | 32 | self.store = hs.get_datastore() |
33 | 33 | self.hs = hs |
34 | self.federation = hs.get_federation_sender() | |
35 | hs.get_federation_registry().register_edu_handler( | |
36 | "m.receipt", self._received_remote_receipt | |
37 | ) | |
34 | ||
35 | # We only need to poke the federation sender explicitly if it's on the | |
36 | # same instance. Other federation sender instances will get notified by | |
37 | # `synapse.app.generic_worker.FederationSenderHandler` when it sees it | |
38 | # in the receipts stream. | |
39 | self.federation_sender = None | |
40 | if hs.should_send_federation(): | |
41 | self.federation_sender = hs.get_federation_sender() | |
42 | ||
43 | # If we can handle the receipt EDUs we do so, otherwise we route them | |
44 | # to the appropriate worker. | |
45 | if hs.get_instance_name() in hs.config.worker.writers.receipts: | |
46 | hs.get_federation_registry().register_edu_handler( | |
47 | "m.receipt", self._received_remote_receipt | |
48 | ) | |
49 | else: | |
50 | hs.get_federation_registry().register_instances_for_edu( | |
51 | "m.receipt", hs.config.worker.writers.receipts, | |
52 | ) | |
53 | ||
38 | 54 | self.clock = self.hs.get_clock() |
39 | 55 | self.state = hs.get_state_handler() |
40 | 56 | |
124 | 140 | if not is_new: |
125 | 141 | return |
126 | 142 | |
127 | await self.federation.send_read_receipt(receipt) | |
143 | if self.federation_sender: | |
144 | await self.federation_sender.send_read_receipt(receipt) | |
128 | 145 | |
129 | 146 | |
130 | 147 | class ReceiptEventSource: |
37 | 37 | from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion |
38 | 38 | from synapse.events import EventBase |
39 | 39 | from synapse.events.utils import copy_power_levels_contents |
40 | from synapse.http.endpoint import parse_and_validate_server_name | |
41 | 40 | from synapse.storage.state import StateFilter |
42 | 41 | from synapse.types import ( |
43 | 42 | JsonDict, |
54 | 53 | from synapse.util import stringutils |
55 | 54 | from synapse.util.async_helpers import Linearizer |
56 | 55 | from synapse.util.caches.response_cache import ResponseCache |
56 | from synapse.util.stringutils import parse_and_validate_server_name | |
57 | 57 | from synapse.visibility import filter_events_for_client |
58 | 58 | |
59 | 59 | from ._base import BaseHandler |
364 | 364 | creation_content = { |
365 | 365 | "room_version": new_room_version.identifier, |
366 | 366 | "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, |
367 | } | |
367 | } # type: JsonDict | |
368 | 368 | |
369 | 369 | # Check if old room was non-federatable |
370 | 370 |
62 | 62 | self.registration_handler = hs.get_registration_handler() |
63 | 63 | self.profile_handler = hs.get_profile_handler() |
64 | 64 | self.event_creation_handler = hs.get_event_creation_handler() |
65 | self.account_data_handler = hs.get_account_data_handler() | |
65 | 66 | |
66 | 67 | self.member_linearizer = Linearizer(name="member") |
67 | 68 | |
252 | 253 | direct_rooms[key].append(new_room_id) |
253 | 254 | |
254 | 255 | # Save back to user's m.direct account data |
255 | await self.store.add_account_data_for_user( | |
256 | await self.account_data_handler.add_account_data_for_user( | |
256 | 257 | user_id, AccountDataTypes.DIRECT, direct_rooms |
257 | 258 | ) |
258 | 259 | break |
262 | 263 | |
263 | 264 | # Copy each room tag to the new room |
264 | 265 | for tag, tag_content in room_tags.items(): |
265 | await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) | |
266 | await self.account_data_handler.add_tag_to_room( | |
267 | user_id, new_room_id, tag, tag_content | |
268 | ) | |
266 | 269 | |
267 | 270 | async def update_membership( |
268 | 271 | self, |
72 | 72 | ) |
73 | 73 | |
74 | 74 | # identifier for the external_ids table |
75 | self._auth_provider_id = "saml" | |
75 | self.idp_id = "saml" | |
76 | ||
77 | # user-facing name of this auth provider | |
78 | self.idp_name = "SAML" | |
79 | ||
80 | # we do not currently support icons for SAML auth, but this is required by | |
81 | # the SsoIdentityProvider protocol type. | |
82 | self.idp_icon = None | |
76 | 83 | |
77 | 84 | # a map from saml session id to Saml2SessionData object |
78 | 85 | self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] |
79 | 86 | |
80 | 87 | self._sso_handler = hs.get_sso_handler() |
81 | ||
82 | def handle_redirect_request( | |
83 | self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None | |
84 | ) -> bytes: | |
88 | self._sso_handler.register_identity_provider(self) | |
89 | ||
90 | async def handle_redirect_request( | |
91 | self, | |
92 | request: SynapseRequest, | |
93 | client_redirect_url: Optional[bytes], | |
94 | ui_auth_session_id: Optional[str] = None, | |
95 | ) -> str: | |
85 | 96 | """Handle an incoming request to /login/sso/redirect |
86 | 97 | |
87 | 98 | Args: |
99 | request: the incoming HTTP request | |
88 | 100 | client_redirect_url: the URL that we should redirect the |
89 | client to when everything is done | |
101 | client to after login (or None for UI Auth). | |
90 | 102 | ui_auth_session_id: The session ID of the ongoing UI Auth (or |
91 | 103 | None if this is a login). |
92 | 104 | |
93 | 105 | Returns: |
94 | 106 | URL to redirect to |
95 | 107 | """ |
108 | if not client_redirect_url: | |
109 | # Some SAML identity providers (e.g. Google) require a | |
110 | # RelayState parameter on requests, so pass in a dummy redirect URL | |
111 | # (which will never get used). | |
112 | client_redirect_url = b"unused" | |
113 | ||
96 | 114 | reqid, info = self._saml_client.prepare_for_authenticate( |
97 | 115 | entityid=self._saml_idp_entityid, relay_state=client_redirect_url |
98 | 116 | ) |
209 | 227 | return |
210 | 228 | |
211 | 229 | return await self._sso_handler.complete_sso_ui_auth_request( |
212 | self._auth_provider_id, | |
230 | self.idp_id, | |
213 | 231 | remote_user_id, |
214 | 232 | current_session.ui_auth_session_id, |
215 | 233 | request, |
305 | 323 | return None |
306 | 324 | |
307 | 325 | await self._sso_handler.complete_sso_login_request( |
308 | self._auth_provider_id, | |
326 | self.idp_id, | |
309 | 327 | remote_user_id, |
310 | 328 | request, |
311 | 329 | client_redirect_url, |
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | import abc | |
14 | 15 | import logging |
15 | from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional | |
16 | from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional | |
17 | from urllib.parse import urlencode | |
16 | 18 | |
17 | 19 | import attr |
18 | from typing_extensions import NoReturn | |
20 | from typing_extensions import NoReturn, Protocol | |
19 | 21 | |
20 | 22 | from twisted.web.http import Request |
21 | 23 | |
22 | from synapse.api.errors import RedirectException, SynapseError | |
24 | from synapse.api.constants import LoginType | |
25 | from synapse.api.errors import Codes, RedirectException, SynapseError | |
26 | from synapse.handlers.ui_auth import UIAuthSessionDataConstants | |
27 | from synapse.http import get_request_user_agent | |
23 | 28 | from synapse.http.server import respond_with_html |
24 | 29 | from synapse.http.site import SynapseRequest |
25 | 30 | from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters |
37 | 42 | |
38 | 43 | Note that the msg that is raised is shown to end-users. |
39 | 44 | """ |
45 | ||
46 | ||
47 | class SsoIdentityProvider(Protocol): | |
48 | """Abstract base class to be implemented by SSO Identity Providers | |
49 | ||
50 | An Identity Provider, or IdP, is an external HTTP service which authenticates a user | |
51 | to say whether they should be allowed to log in, or perform a given action. | |
52 | ||
53 | Synapse supports various implementations of IdPs, including OpenID Connect, SAML, | |
54 | and CAS. | |
55 | ||
56 | The main entry point is `handle_redirect_request`, which should return a URI to | |
57 | redirect the user's browser to the IdP's authentication page. | |
58 | ||
59 | Each IdP should be registered with the SsoHandler via | |
60 | `hs.get_sso_handler().register_identity_provider()`, so that requests to | |
61 | `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched. | |
62 | """ | |
63 | ||
64 | @property | |
65 | @abc.abstractmethod | |
66 | def idp_id(self) -> str: | |
67 | """A unique identifier for this SSO provider | |
68 | ||
69 | Eg, "saml", "cas", "github" | |
70 | """ | |
71 | ||
72 | @property | |
73 | @abc.abstractmethod | |
74 | def idp_name(self) -> str: | |
75 | """User-facing name for this provider""" | |
76 | ||
77 | @property | |
78 | def idp_icon(self) -> Optional[str]: | |
79 | """Optional MXC URI for user-facing icon""" | |
80 | return None | |
81 | ||
82 | @abc.abstractmethod | |
83 | async def handle_redirect_request( | |
84 | self, | |
85 | request: SynapseRequest, | |
86 | client_redirect_url: Optional[bytes], | |
87 | ui_auth_session_id: Optional[str] = None, | |
88 | ) -> str: | |
89 | """Handle an incoming request to /login/sso/redirect | |
90 | ||
91 | Args: | |
92 | request: the incoming HTTP request | |
93 | client_redirect_url: the URL that we should redirect the | |
94 | client to after login (or None for UI Auth). | |
95 | ui_auth_session_id: The session ID of the ongoing UI Auth (or | |
96 | None if this is a login). | |
97 | ||
98 | Returns: | |
99 | URL to redirect to | |
100 | """ | |
101 | raise NotImplementedError() | |
40 | 102 | |
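Because ``SsoIdentityProvider`` is a ``Protocol``, any class with the right
attributes satisfies it without inheriting from it. A minimal illustrative
implementation (the class name and redirect URL are made up):

    from typing import Optional

    class DummyIdentityProvider:
        """Toy provider satisfying the SsoIdentityProvider protocol."""

        idp_id = "dummy"
        idp_name = "Dummy IdP"
        idp_icon = None  # type: Optional[str]

        async def handle_redirect_request(
            self,
            request,
            client_redirect_url: Optional[bytes],
            ui_auth_session_id: Optional[str] = None,
        ) -> str:
            # a real provider would set a session cookie here and build an
            # authentication URL pointing at its IdP
            return "https://idp.example.com/auth?client=synapse"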
41 | 103 | |
42 | 104 | @attr.s |
90 | 152 | self._store = hs.get_datastore() |
91 | 153 | self._server_name = hs.hostname |
92 | 154 | self._registration_handler = hs.get_registration_handler() |
155 | self._auth_handler = hs.get_auth_handler() | |
93 | 156 | self._error_template = hs.config.sso_error_template |
94 | self._auth_handler = hs.get_auth_handler() | |
157 | self._bad_user_template = hs.config.sso_auth_bad_user_template | |
158 | ||
159 | # The following template is shown after a successful user interactive | |
160 | # authentication session. It tells the user they can close the window. | |
161 | self._sso_auth_success_template = hs.config.sso_auth_success_template | |
95 | 162 | |
96 | 163 | # a lock on the mappings |
97 | 164 | self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) |
98 | 165 | |
99 | 166 | # a map from session id to session data |
100 | 167 | self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] |
168 | ||
169 | # map from idp_id to SsoIdentityProvider | |
170 | self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] | |
171 | ||
172 | def register_identity_provider(self, p: SsoIdentityProvider): | |
173 | p_id = p.idp_id | |
174 | assert p_id not in self._identity_providers | |
175 | self._identity_providers[p_id] = p | |
176 | ||
177 | def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: | |
178 | """Get the configured identity providers""" | |
179 | return self._identity_providers | |
180 | ||
181 | async def get_identity_providers_for_user( | |
182 | self, user_id: str | |
183 | ) -> Mapping[str, SsoIdentityProvider]: | |
184 | """Get the SsoIdentityProviders which a user has used | |
185 | ||
186 | Given a user id, get the identity providers that that user has used to log in | |
187 | with in the past (and thus could use to re-identify themselves for UI Auth). | |
188 | ||
189 | Args: | |
190 | user_id: MXID of user to look up | |
191 | ||
192 | Returns: | |
193 | a map of idp_id to SsoIdentityProvider | |
194 | """ | |
195 | external_ids = await self._store.get_external_ids_by_user(user_id) | |
196 | ||
197 | valid_idps = {} | |
198 | for idp_id, _ in external_ids: | |
199 | idp = self._identity_providers.get(idp_id) | |
200 | if not idp: | |
201 | logger.warning( | |
202 | "User %r has an SSO mapping for IdP %r, but this is no longer " | |
203 | "configured.", | |
204 | user_id, | |
205 | idp_id, | |
206 | ) | |
207 | else: | |
208 | valid_idps[idp_id] = idp | |
209 | ||
210 | return valid_idps | |
101 | 211 | |
102 | 212 | def render_error( |
103 | 213 | self, |
122 | 232 | error=error, error_description=error_description |
123 | 233 | ) |
124 | 234 | respond_with_html(request, code, html) |
235 | ||
236 | async def handle_redirect_request( | |
237 | self, request: SynapseRequest, client_redirect_url: bytes, | |
238 | ) -> str: | |
239 | """Handle a request to /login/sso/redirect | |
240 | ||
241 | Args: | |
242 | request: incoming HTTP request | |
243 | client_redirect_url: the URL that we should redirect the | |
244 | client to after login. | |
245 | ||
246 | Returns: | |
247 | the URI to redirect to | |
248 | """ | |
249 | if not self._identity_providers: | |
250 | raise SynapseError( | |
251 | 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED | |
252 | ) | |
253 | ||
254 | # if we only have one auth provider, redirect to it directly | |
255 | if len(self._identity_providers) == 1: | |
256 | ap = next(iter(self._identity_providers.values())) | |
257 | return await ap.handle_redirect_request(request, client_redirect_url) | |
258 | ||
259 | # otherwise, redirect to the IDP picker | |
260 | return "/_synapse/client/pick_idp?" + urlencode( | |
261 | (("redirectUrl", client_redirect_url),) | |
262 | ) | |
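For example, with more than one provider configured, a client hitting
``/login/sso/redirect`` is bounced to the picker with its redirect URL
percent-encoded into the query string:

    from urllib.parse import urlencode

    url = "/_synapse/client/pick_idp?" + urlencode(
        (("redirectUrl", b"https://client.example.com/?next=%2Fhome"),)
    )
    # -> /_synapse/client/pick_idp?redirectUrl=
    #        https%3A%2F%2Fclient.example.com%2F%3Fnext%3D%252Fhome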
125 | 263 | |
126 | 264 | async def get_sso_user_by_remote_user_id( |
127 | 265 | self, auth_provider_id: str, remote_user_id: str |
267 | 405 | attributes, |
268 | 406 | auth_provider_id, |
269 | 407 | remote_user_id, |
270 | request.get_user_agent(""), | |
408 | get_request_user_agent(request), | |
271 | 409 | request.getClientIP(), |
272 | 410 | ) |
273 | 411 | |
450 | 588 | auth_provider_id, remote_user_id, |
451 | 589 | ) |
452 | 590 | |
591 | user_id_to_verify = await self._auth_handler.get_session_data( | |
592 | ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID | |
593 | ) # type: str | |
594 | ||
453 | 595 | if not user_id: |
454 | 596 | logger.warning( |
455 | 597 | "Remote user %s/%s has not previously logged in here: UIA will fail", |
456 | 598 | auth_provider_id, |
457 | 599 | remote_user_id, |
458 | 600 | ) |
459 | # Let the UIA flow handle this the same as if they presented creds for a | |
460 | # different user. | |
461 | user_id = "" | |
462 | ||
463 | await self._auth_handler.complete_sso_ui_auth( | |
464 | user_id, ui_auth_session_id, request | |
465 | ) | |
601 | elif user_id != user_id_to_verify: | |
602 | logger.warning( | |
603 | "Remote user %s/%s mapped onto incorrect user %s: UIA will fail", | |
604 | auth_provider_id, | |
605 | remote_user_id, | |
606 | user_id, | |
607 | ) | |
608 | else: | |
609 | # success! | |
610 | # Mark the stage of the authentication as successful. | |
611 | await self._store.mark_ui_auth_stage_complete( | |
612 | ui_auth_session_id, LoginType.SSO, user_id | |
613 | ) | |
614 | ||
615 | # Render the HTML confirmation page and return. | |
616 | html = self._sso_auth_success_template | |
617 | respond_with_html(request, 200, html) | |
618 | return | |
619 | ||
620 | # the user_id didn't match: mark the stage of the authentication as unsuccessful | |
621 | await self._store.mark_ui_auth_stage_complete( | |
622 | ui_auth_session_id, LoginType.SSO, "" | |
623 | ) | |
624 | ||
625 | # render an error page. | |
626 | html = self._bad_user_template.render( | |
627 | server_name=self._server_name, user_id_to_verify=user_id_to_verify, | |
628 | ) | |
629 | respond_with_html(request, 200, html) | |
466 | 630 | |
467 | 631 | async def check_username_availability( |
468 | 632 | self, localpart: str, session_id: str, |
533 | 697 | attributes, |
534 | 698 | session.auth_provider_id, |
535 | 699 | session.remote_user_id, |
536 | request.get_user_agent(""), | |
700 | get_request_user_agent(request), | |
537 | 701 | request.getClientIP(), |
538 | 702 | ) |
539 | 703 |
19 | 19 | """ |
20 | 20 | |
21 | 21 | from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401 |
22 | ||
23 | ||
24 | class UIAuthSessionDataConstants: | |
25 | """Constants for use with AuthHandler.set_session_data""" | |
26 | ||
27 | # used during registration and password reset to store a hashed copy of the | |
28 | # password, so that the client does not need to submit it each time. | |
29 | PASSWORD_HASH = "password_hash" | |
30 | ||
31 | # used during registration to store the mxid of the registered user | |
32 | REGISTERED_USER_ID = "registered_user_id" | |
33 | ||
34 | # used by validate_user_via_ui_auth to store the mxid of the user we are validating | |
35 | # for. | |
36 | REQUEST_USER_ID = "request_user_id" |
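A hedged sketch of how these keys are meant to be used; ``get_session_data``
appears in the SSO handler changes above, while the ``set_session_data``
signature is assumed here:

    # store the user being validated when UI auth begins...
    await auth_handler.set_session_data(
        session_id, UIAuthSessionDataConstants.REQUEST_USER_ID, user_id
    )
    # ...then read it back when the SSO stage completes, to check that the
    # user who authenticated via SSO is the one the session is validating
    expected = await auth_handler.get_session_data(
        session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
    )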
16 | 16 | |
17 | 17 | from twisted.internet import task |
18 | 18 | from twisted.web.client import FileBodyProducer |
19 | from twisted.web.iweb import IRequest | |
19 | 20 | |
20 | 21 | from synapse.api.errors import SynapseError |
21 | 22 | |
49 | 50 | FileBodyProducer.stopProducing(self) |
50 | 51 | except task.TaskStopped: |
51 | 52 | pass |
53 | ||
54 | ||
55 | def get_request_user_agent(request: IRequest, default: str = "") -> str: | |
56 | """Return the last User-Agent header, or the given default. | |
57 | """ | |
58 | # There could be raw utf-8 bytes in the User-Agent header. | |
59 | ||
60 | # N.B. if you don't do this, the logger explodes cryptically | |
61 | # with maximum recursion trying to log errors about | |
62 | # the charset problem. | |
63 | # c.f. https://github.com/matrix-org/synapse/issues/3471 | |
64 | ||
65 | h = request.getHeader(b"User-Agent") | |
66 | return h.decode("ascii", "replace") if h else default |
31 | 31 | |
32 | 32 | import treq |
33 | 33 | from canonicaljson import encode_canonical_json |
34 | from netaddr import IPAddress, IPSet | |
34 | from netaddr import AddrFormatError, IPAddress, IPSet | |
35 | 35 | from prometheus_client import Counter |
36 | 36 | from zope.interface import implementer, provider |
37 | 37 | |
260 | 260 | |
261 | 261 | try: |
262 | 262 | ip_address = IPAddress(h.hostname) |
263 | ||
263 | except AddrFormatError: | |
264 | # Not an IP | |
265 | pass | |
266 | else: | |
264 | 267 | if check_against_blacklist( |
265 | 268 | ip_address, self._ip_whitelist, self._ip_blacklist |
266 | 269 | ): |
267 | 270 | logger.info("Blocking access to %s due to blacklist" % (ip_address,)) |
268 | 271 | e = SynapseError(403, "IP address blocked by IP blacklist entry") |
269 | 272 | return defer.fail(Failure(e)) |
270 | except Exception: | |
271 | # Not an IP | |
272 | pass | |
273 | 273 | |
274 | 274 | return self._agent.request( |
275 | 275 | method, uri, headers=headers, bodyProducer=bodyProducer |
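``netaddr`` raises ``AddrFormatError`` for anything that is not an IP literal,
which is what lets the narrowed ``except`` above distinguish a hostname from a
blacklisted address without swallowing unrelated errors the way the old bare
``except Exception`` did:

    from netaddr import AddrFormatError, IPAddress

    IPAddress("10.0.0.1")         # parses, so it is checked against the lists
    try:
        IPAddress("example.com")  # not an IP literal
    except AddrFormatError:
        pass                      # treated as a DNS name and resolved later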
723 | 723 | read_body_with_max_size(response, output_stream, max_size) |
724 | 724 | ) |
725 | 725 | except BodyExceededMaxSize: |
726 | SynapseError( | |
726 | raise SynapseError( | |
727 | 727 | 502, |
728 | 728 | "Requested file is too large > %r bytes" % (max_size,), |
729 | 729 | Codes.TOO_LARGE, |
765 | 765 | self.max_size = max_size |
766 | 766 | |
767 | 767 | def dataReceived(self, data: bytes) -> None: |
768 | # If the deferred was called, bail early. | |
769 | if self.deferred.called: | |
770 | return | |
771 | ||
768 | 772 | self.stream.write(data) |
769 | 773 | self.length += len(data) |
774 | # The first time the maximum size is exceeded, error and cancel the | |
775 | # connection. dataReceived might be called again if data was received | |
776 | # in the meantime. | |
770 | 777 | if self.max_size is not None and self.length >= self.max_size: |
771 | 778 | self.deferred.errback(BodyExceededMaxSize()) |
772 | self.deferred = defer.Deferred() | |
773 | 779 | self.transport.loseConnection() |
774 | 780 | |
775 | 781 | def connectionLost(self, reason: Failure) -> None: |
782 | # If the maximum size was already exceeded, there's nothing to do. | |
783 | if self.deferred.called: | |
784 | return | |
785 | ||
776 | 786 | if reason.check(ResponseDone): |
777 | 787 | self.deferred.callback(self.length) |
778 | 788 | elif reason.check(PotentialDataLoss): |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2014-2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | import logging | |
15 | import re | |
16 | ||
17 | logger = logging.getLogger(__name__) | |
18 | ||
19 | ||
20 | def parse_server_name(server_name): | |
21 | """Split a server name into host/port parts. | |
22 | ||
23 | Args: | |
24 | server_name (str): server name to parse | |
25 | ||
26 | Returns: | |
27 | Tuple[str, int|None]: host/port parts. | |
28 | ||
29 | Raises: | |
30 | ValueError if the server name could not be parsed. | |
31 | """ | |
32 | try: | |
33 | if server_name[-1] == "]": | |
34 | # ipv6 literal, hopefully | |
35 | return server_name, None | |
36 | ||
37 | domain_port = server_name.rsplit(":", 1) | |
38 | domain = domain_port[0] | |
39 | port = int(domain_port[1]) if domain_port[1:] else None | |
40 | return domain, port | |
41 | except Exception: | |
42 | raise ValueError("Invalid server name '%s'" % server_name) | |
43 | ||
44 | ||
45 | VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") | |
46 | ||
47 | ||
48 | def parse_and_validate_server_name(server_name): | |
49 | """Split a server name into host/port parts and do some basic validation. | |
50 | ||
51 | Args: | |
52 | server_name (str): server name to parse | |
53 | ||
54 | Returns: | |
55 | Tuple[str, int|None]: host/port parts. | |
56 | ||
57 | Raises: | |
58 | ValueError if the server name could not be parsed. | |
59 | """ | |
60 | host, port = parse_server_name(server_name) | |
61 | ||
62 | # these tests don't need to be bulletproof as we'll find out soon enough | |
63 | # if somebody is giving us invalid data. What we *do* need is to be sure | |
64 | # that nobody is sneaking IP literals in that look like hostnames, etc. | |
65 | ||
66 | # look for ipv6 literals | |
67 | if host[0] == "[": | |
68 | if host[-1] != "]": | |
69 | raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) | |
70 | return host, port | |
71 | ||
72 | # otherwise it should only be alphanumerics. | |
73 | if not VALID_HOST_REGEX.match(host): | |
74 | raise ValueError( | |
75 | "Server name '%s' contains invalid characters" % (server_name,) | |
76 | ) | |
77 | ||
78 | return host, port |
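This helper survives unchanged in ``synapse.util.stringutils`` (see the import
change above); its behaviour on typical inputs, for reference:

    parse_and_validate_server_name("matrix.org")         # ("matrix.org", None)
    parse_and_validate_server_name("matrix.org:8448")    # ("matrix.org", 8448)
    parse_and_validate_server_name("[2001:db8::1]:443")  # ("[2001:db8::1]", 443)
    parse_and_validate_server_name("bad host!")          # raises ValueError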
101 | 101 | pool=self._pool, |
102 | 102 | contextFactory=tls_client_options_factory, |
103 | 103 | ), |
104 | self._reactor, | |
105 | 104 | ip_blacklist=ip_blacklist, |
106 | 105 | ), |
107 | 106 | user_agent=self.user_agent, |
173 | 173 | d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) |
174 | 174 | |
175 | 175 | body = await make_deferred_yieldable(d) |
176 | except ValueError as e: | |
177 | # The JSON content was invalid. | |
178 | logger.warning( | |
179 | "{%s} [%s] Failed to parse JSON response - %s %s", | |
180 | request.txn_id, | |
181 | request.destination, | |
182 | request.method, | |
183 | request.uri.decode("ascii"), | |
184 | ) | |
185 | raise RequestSendFailed(e, can_retry=False) from e | |
176 | 186 | except defer.TimeoutError as e: |
177 | 187 | logger.warning( |
178 | 188 | "{%s} [%s] Timed out reading response - %s %s", |
985 | 995 | logger.warning( |
986 | 996 | "{%s} [%s] %s", request.txn_id, request.destination, msg, |
987 | 997 | ) |
988 | SynapseError(502, msg, Codes.TOO_LARGE) | |
998 | raise SynapseError(502, msg, Codes.TOO_LARGE) | |
989 | 999 | except Exception as e: |
990 | 1000 | logger.warning( |
991 | 1001 | "{%s} [%s] Error reading response: %s", |
19 | 19 | from twisted.web.server import Request, Site |
20 | 20 | |
21 | 21 | from synapse.config.server import ListenerConfig |
22 | from synapse.http import redact_uri | |
22 | from synapse.http import get_request_user_agent, redact_uri | |
23 | 23 | from synapse.http.request_metrics import RequestMetrics, requests_counter |
24 | 24 | from synapse.logging.context import LoggingContext, PreserveLoggingContext |
25 | 25 | from synapse.types import Requester |
111 | 111 | if isinstance(method, bytes): |
112 | 112 | method = self.method.decode("ascii") |
113 | 113 | return method |
114 | ||
115 | def get_user_agent(self, default: str) -> str: | |
116 | """Return the last User-Agent header, or the given default. | |
117 | """ | |
118 | user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] | |
119 | if user_agent is None: | |
120 | return default | |
121 | ||
122 | return user_agent.decode("ascii", "replace") | |
123 | 114 | |
124 | 115 | def render(self, resrc): |
125 | 116 | # this is called once a Resource has been found to serve the request; in our |
291 | 282 | # and can see that we're doing something wrong. |
292 | 283 | authenticated_entity = repr(self.requester) # type: ignore[unreachable] |
293 | 284 | |
294 | # ...or could be raw utf-8 bytes in the User-Agent header. | |
295 | # N.B. if you don't do this, the logger explodes cryptically | |
296 | # with maximum recursion trying to log errors about | |
297 | # the charset problem. | |
298 | # c.f. https://github.com/matrix-org/synapse/issues/3471 | |
299 | user_agent = self.get_user_agent("-") | |
285 | user_agent = get_request_user_agent(self, "-") | |
300 | 286 | |
301 | 287 | code = str(self.code) |
302 | 288 | if not self.finished: |
251 | 251 | "scope", |
252 | 252 | ] |
253 | 253 | |
254 | def __init__(self, name=None, parent_context=None, request=None) -> None: | |
254 | def __init__( | |
255 | self, | |
256 | name: Optional[str] = None, | |
257 | parent_context: "Optional[LoggingContext]" = None, | |
258 | request: Optional[str] = None, | |
259 | ) -> None: | |
255 | 260 | self.previous_context = current_context() |
256 | 261 | self.name = name |
257 | 262 | |
535 | 540 | def __init__(self, request: str = ""): |
536 | 541 | self._default_request = request |
537 | 542 | |
538 | def filter(self, record) -> Literal[True]: | |
543 | def filter(self, record: logging.LogRecord) -> Literal[True]: | |
539 | 544 | """Add each fields from the logging contexts to the record. |
540 | 545 | Returns: |
541 | 546 | True to include the record in the log output. |
542 | 547 | """ |
543 | 548 | context = current_context() |
544 | record.request = self._default_request | |
549 | record.request = self._default_request # type: ignore | |
545 | 550 | |
546 | 551 | # context should never be None, but if it somehow ends up being, then |
547 | 552 | # we end up in a death spiral of infinite loops, so let's check, for |
548 | 553 | # robustness' sake. |
549 | 554 | if context is not None: |
550 | 555 | # Logging is interested in the request. |
551 | record.request = context.request | |
556 | record.request = context.request # type: ignore | |
552 | 557 | |
553 | 558 | return True |
554 | 559 | |
615 | 620 | return current |
616 | 621 | |
617 | 622 | |
618 | def nested_logging_context( | |
619 | suffix: str, parent_context: Optional[LoggingContext] = None | |
620 | ) -> LoggingContext: | |
623 | def nested_logging_context(suffix: str) -> LoggingContext: | |
621 | 624 | """Creates a new logging context as a child of another. |
622 | 625 | |
623 | 626 | The nested logging context will have a 'request' made up of the parent context's |
631 | 634 | # ... do stuff |
632 | 635 | |
633 | 636 | Args: |
634 | suffix (str): suffix to add to the parent context's 'request'. | |
635 | parent_context (LoggingContext|None): parent context. Will use the current context | |
636 | if None. | |
637 | suffix: suffix to add to the parent context's 'request'. | |
637 | 638 | |
638 | 639 | Returns: |
639 | 640 | LoggingContext: new logging context. |
640 | 641 | """ |
641 | if parent_context is not None: | |
642 | context = parent_context # type: LoggingContextOrSentinel | |
642 | curr_context = current_context() | |
643 | if not curr_context: | |
644 | logger.warning( | |
645 | "Starting nested logging context from sentinel context: metrics will be lost" | |
646 | ) | |
647 | parent_context = None | |
648 | prefix = "" | |
643 | 649 | else: |
644 | context = current_context() | |
645 | return LoggingContext( | |
646 | parent_context=context, request=str(context.request) + "-" + suffix | |
647 | ) | |
650 | assert isinstance(curr_context, LoggingContext) | |
651 | parent_context = curr_context | |
652 | prefix = str(parent_context.request) | |
653 | return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix) | |
648 | 654 | |
649 | 655 | |
650 | 656 | def preserve_fn(f): |
821 | 827 | Deferred: A Deferred which fires a callback with the result of `f`, or an |
822 | 828 | errback if `f` throws an exception. |
823 | 829 | """ |
824 | logcontext = current_context() | |
830 | curr_context = current_context() | |
831 | if not curr_context: | |
832 | logger.warning( | |
833 | "Calling defer_to_threadpool from sentinel context: metrics will be lost" | |
834 | ) | |
835 | parent_context = None | |
836 | else: | |
837 | assert isinstance(curr_context, LoggingContext) | |
838 | parent_context = curr_context | |
825 | 839 | |
826 | 840 | def g(): |
827 | with LoggingContext(parent_context=logcontext): | |
841 | with LoggingContext(parent_context=parent_context): | |
828 | 842 | return f(*args, **kwargs) |
829 | 843 | |
830 | 844 | return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) |
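Since `defer_to_threadpool` already wraps its result in `make_deferred_yieldable`, callers inside a logging context can simply await it. A hedged sketch; `hash_blob` and its arguments are made up for illustration:

```python
import hashlib

from synapse.logging.context import defer_to_threadpool


async def hash_blob(hs, blob: bytes) -> str:
    # Run the blocking digest on the reactor's thread pool; the returned
    # Deferred already follows the logcontext rules, so it can be awaited.
    return await defer_to_threadpool(
        hs.get_reactor(),
        hs.get_reactor().getThreadPool(),
        lambda: hashlib.sha256(blob).hexdigest(),
    )
```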
395 | 395 | |
396 | 396 | Will wake up all listeners for the given users and rooms. |
397 | 397 | """ |
398 | with PreserveLoggingContext(): | |
399 | with Measure(self.clock, "on_new_event"): | |
400 | user_streams = set() | |
401 | ||
402 | for user in users: | |
403 | user_stream = self.user_to_user_stream.get(str(user)) | |
404 | if user_stream is not None: | |
405 | user_streams.add(user_stream) | |
406 | ||
407 | for room in rooms: | |
408 | user_streams |= self.room_to_user_streams.get(room, set()) | |
409 | ||
410 | time_now_ms = self.clock.time_msec() | |
411 | for user_stream in user_streams: | |
412 | try: | |
413 | user_stream.notify(stream_key, new_token, time_now_ms) | |
414 | except Exception: | |
415 | logger.exception("Failed to notify listener") | |
416 | ||
417 | self.notify_replication() | |
418 | ||
419 | # Notify appservices | |
420 | self._notify_app_services_ephemeral( | |
421 | stream_key, new_token, users, | |
422 | ) | |
398 | with Measure(self.clock, "on_new_event"): | |
399 | user_streams = set() | |
400 | ||
401 | for user in users: | |
402 | user_stream = self.user_to_user_stream.get(str(user)) | |
403 | if user_stream is not None: | |
404 | user_streams.add(user_stream) | |
405 | ||
406 | for room in rooms: | |
407 | user_streams |= self.room_to_user_streams.get(room, set()) | |
408 | ||
409 | time_now_ms = self.clock.time_msec() | |
410 | for user_stream in user_streams: | |
411 | try: | |
412 | user_stream.notify(stream_key, new_token, time_now_ms) | |
413 | except Exception: | |
414 | logger.exception("Failed to notify listener") | |
415 | ||
416 | self.notify_replication() | |
417 | ||
418 | # Notify appservices | |
419 | self._notify_app_services_ephemeral( | |
420 | stream_key, new_token, users, | |
421 | ) | |
423 | 422 | |
424 | 423 | def on_new_replication_data(self) -> None: |
425 | 424 | """Used to inform replication listeners that something has happened |
202 | 202 | |
203 | 203 | condition_cache = {} # type: Dict[str, bool] |
204 | 204 | |
205 | # If the event is not a state event, check if any users ignore the sender. | 
206 | if not event.is_state(): | |
207 | ignorers = await self.store.ignored_by(event.sender) | |
208 | else: | |
209 | ignorers = set() | |
210 | ||
205 | 211 | for uid, rules in rules_by_user.items(): |
206 | 212 | if event.sender == uid: |
207 | 213 | continue |
208 | 214 | |
209 | if not event.is_state(): | |
210 | is_ignored = await self.store.is_ignored_by(event.sender, uid) | |
211 | if is_ignored: | |
212 | continue | |
215 | if uid in ignorers: | |
216 | continue | |
213 | 217 | |
214 | 218 | display_name = None |
215 | 219 | profile_info = room_members.get(uid) |
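The hunk above swaps a per-recipient `is_ignored_by` database call for a single `ignored_by(event.sender)` lookup, so pushing to a room with N members costs one query plus N set-membership tests instead of N queries. The same pattern in isolation (`filter_recipients` is a hypothetical helper):

```python
from typing import Iterable, List


async def filter_recipients(store, sender: str, recipients: Iterable[str]) -> List[str]:
    # One query up front: every user who has ignored this sender.
    ignorers = await store.ignored_by(sender)
    # Then cheap in-memory membership checks per recipient.
    return [uid for uid in recipients if uid not in ignorers]
```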
85 | 85 | |
86 | 86 | CONDITIONAL_REQUIREMENTS = { |
87 | 87 | "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], |
88 | # we use execute_batch, which arrived in psycopg 2.7. | |
89 | "postgres": ["psycopg2>=2.7"], | |
88 | # we use execute_values with the fetch param, which arrived in psycopg 2.8. | |
89 | "postgres": ["psycopg2>=2.8"], | |
90 | 90 | # ACME support is required to provision TLS certificates from authorities |
91 | 91 | # that use the protocol, such as Let's Encrypt. |
92 | 92 | "acme": [ |
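The minimum psycopg2 version rises from 2.7 to 2.8 because `execute_values` only gained its `fetch` parameter in 2.8, which lets a batched `INSERT ... RETURNING` hand the generated rows back. A standalone sketch; the DSN and table are placeholders:

```python
import psycopg2
from psycopg2.extras import execute_values

conn = psycopg2.connect("dbname=synapse")  # placeholder DSN

with conn, conn.cursor() as cur:
    returned = execute_values(
        cur,
        "INSERT INTO demo (name) VALUES %s RETURNING id",
        [("alice",), ("bob",)],
        fetch=True,  # requires psycopg2 >= 2.8
    )
    # `returned` holds one RETURNING tuple per inserted row.
```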
14 | 14 | |
15 | 15 | from synapse.http.server import JsonResource |
16 | 16 | from synapse.replication.http import ( |
17 | account_data, | |
17 | 18 | devices, |
18 | 19 | federation, |
19 | 20 | login, |
39 | 40 | presence.register_servlets(hs, self) |
40 | 41 | membership.register_servlets(hs, self) |
41 | 42 | streams.register_servlets(hs, self) |
43 | account_data.register_servlets(hs, self) | |
42 | 44 | |
43 | 45 | # The following can't currently be instantiated on workers. |
44 | 46 | if hs.config.worker.worker_app is None: |
176 | 176 | |
177 | 177 | @trace(opname="outgoing_replication_request") |
178 | 178 | @outgoing_gauge.track_inprogress() |
179 | async def send_request(instance_name="master", **kwargs): | |
179 | async def send_request(*, instance_name="master", **kwargs): | |
180 | 180 | if instance_name == local_instance_name: |
181 | 181 | raise Exception("Trying to send HTTP request to self") |
182 | 182 | if instance_name == "master": |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import logging | |
16 | ||
17 | from synapse.http.servlet import parse_json_object_from_request | |
18 | from synapse.replication.http._base import ReplicationEndpoint | |
19 | ||
20 | logger = logging.getLogger(__name__) | |
21 | ||
22 | ||
23 | class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | |
24 | """Add user account data on the appropriate account data worker. | |
25 | ||
26 | Request format: | |
27 | ||
28 | POST /_synapse/replication/add_user_account_data/:user_id/:type | |
29 | ||
30 | { | |
31 | "content": { ... }, | |
32 | } | |
33 | ||
34 | """ | |
35 | ||
36 | NAME = "add_user_account_data" | |
37 | PATH_ARGS = ("user_id", "account_data_type") | |
38 | CACHE = False | |
39 | ||
40 | def __init__(self, hs): | |
41 | super().__init__(hs) | |
42 | ||
43 | self.handler = hs.get_account_data_handler() | |
44 | self.clock = hs.get_clock() | |
45 | ||
46 | @staticmethod | |
47 | async def _serialize_payload(user_id, account_data_type, content): | |
48 | payload = { | |
49 | "content": content, | |
50 | } | |
51 | ||
52 | return payload | |
53 | ||
54 | async def _handle_request(self, request, user_id, account_data_type): | |
55 | content = parse_json_object_from_request(request) | |
56 | ||
57 | max_stream_id = await self.handler.add_account_data_for_user( | |
58 | user_id, account_data_type, content["content"] | |
59 | ) | |
60 | ||
61 | return 200, {"max_stream_id": max_stream_id} | |
62 | ||
63 | ||
64 | class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | |
65 | """Add room account data on the appropriate account data worker. | |
66 | ||
67 | Request format: | |
68 | ||
69 | POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type | |
70 | ||
71 | { | |
72 | "content": { ... }, | |
73 | } | |
74 | ||
75 | """ | |
76 | ||
77 | NAME = "add_room_account_data" | |
78 | PATH_ARGS = ("user_id", "room_id", "account_data_type") | |
79 | CACHE = False | |
80 | ||
81 | def __init__(self, hs): | |
82 | super().__init__(hs) | |
83 | ||
84 | self.handler = hs.get_account_data_handler() | |
85 | self.clock = hs.get_clock() | |
86 | ||
87 | @staticmethod | |
88 | async def _serialize_payload(user_id, room_id, account_data_type, content): | |
89 | payload = { | |
90 | "content": content, | |
91 | } | |
92 | ||
93 | return payload | |
94 | ||
95 | async def _handle_request(self, request, user_id, room_id, account_data_type): | |
96 | content = parse_json_object_from_request(request) | |
97 | ||
98 | max_stream_id = await self.handler.add_account_data_to_room( | |
99 | user_id, room_id, account_data_type, content["content"] | |
100 | ) | |
101 | ||
102 | return 200, {"max_stream_id": max_stream_id} | |
103 | ||
104 | ||
105 | class ReplicationAddTagRestServlet(ReplicationEndpoint): | |
106 | """Add tag on the appropriate account data worker. | |
107 | ||
108 | Request format: | |
109 | ||
110 | POST /_synapse/replication/add_tag/:user_id/:room_id/:tag | |
111 | ||
112 | { | |
113 | "content": { ... }, | |
114 | } | |
115 | ||
116 | """ | |
117 | ||
118 | NAME = "add_tag" | |
119 | PATH_ARGS = ("user_id", "room_id", "tag") | |
120 | CACHE = False | |
121 | ||
122 | def __init__(self, hs): | |
123 | super().__init__(hs) | |
124 | ||
125 | self.handler = hs.get_account_data_handler() | |
126 | self.clock = hs.get_clock() | |
127 | ||
128 | @staticmethod | |
129 | async def _serialize_payload(user_id, room_id, tag, content): | |
130 | payload = { | |
131 | "content": content, | |
132 | } | |
133 | ||
134 | return payload | |
135 | ||
136 | async def _handle_request(self, request, user_id, room_id, tag): | |
137 | content = parse_json_object_from_request(request) | |
138 | ||
139 | max_stream_id = await self.handler.add_tag_to_room( | |
140 | user_id, room_id, tag, content["content"] | |
141 | ) | |
142 | ||
143 | return 200, {"max_stream_id": max_stream_id} | |
144 | ||
145 | ||
146 | class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | |
147 | """Remove tag on the appropriate account data worker. | |
148 | ||
149 | Request format: | |
150 | ||
151 | POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag | |
152 | ||
153 | {} | |
154 | ||
155 | """ | |
156 | ||
157 | NAME = "remove_tag" | |
158 | PATH_ARGS = ( | |
159 | "user_id", | |
160 | "room_id", | |
161 | "tag", | |
162 | ) | |
163 | CACHE = False | |
164 | ||
165 | def __init__(self, hs): | |
166 | super().__init__(hs) | |
167 | ||
168 | self.handler = hs.get_account_data_handler() | |
169 | self.clock = hs.get_clock() | |
170 | ||
171 | @staticmethod | |
172 | async def _serialize_payload(user_id, room_id, tag): | |
173 | ||
174 | return {} | |
175 | ||
176 | async def _handle_request(self, request, user_id, room_id, tag): | |
177 | max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) | |
178 | ||
179 | return 200, {"max_stream_id": max_stream_id} | |
180 | ||
181 | ||
182 | def register_servlets(hs, http_server): | |
183 | ReplicationUserAccountDataRestServlet(hs).register(http_server) | |
184 | ReplicationRoomAccountDataRestServlet(hs).register(http_server) | |
185 | ReplicationAddTagRestServlet(hs).register(http_server) | |
186 | ReplicationRemoveTagRestServlet(hs).register(http_server) |
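Each of the servlets above also doubles as a client factory via `ReplicationEndpoint.make_client`, so a worker that is not the account-data writer can proxy the write over HTTP. A hedged sketch of invoking the user account-data endpoint; the worker instance name and values are illustrative:

```python
# On a non-writer worker: build a client once, then call it like a coroutine.
client = ReplicationUserAccountDataRestServlet.make_client(hs)

response = await client(
    instance_name="account_data1",  # illustrative writer instance name
    user_id="@alice:example.com",
    account_data_type="m.example",
    content={"setting": True},
)
max_stream_id = response["max_stream_id"]
```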
32 | 32 | database, |
33 | 33 | stream_name="caches", |
34 | 34 | instance_name=hs.get_instance_name(), |
35 | table="cache_invalidation_stream_by_instance", | |
36 | instance_column="instance_name", | |
37 | id_column="stream_id", | |
35 | tables=[ | |
36 | ( | |
37 | "cache_invalidation_stream_by_instance", | |
38 | "instance_name", | |
39 | "stream_id", | |
40 | ) | |
41 | ], | |
38 | 42 | sequence_name="cache_invalidation_stream_seq", |
39 | 43 | writers=[], |
40 | 44 | ) # type: Optional[MultiWriterIdGenerator] |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | 16 | from synapse.replication.slave.storage._base import BaseSlavedStore |
17 | from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |
18 | from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream | |
19 | from synapse.storage.database import DatabasePool | |
20 | 17 | from synapse.storage.databases.main.account_data import AccountDataWorkerStore |
21 | 18 | from synapse.storage.databases.main.tags import TagsWorkerStore |
22 | 19 | |
23 | 20 | |
24 | 21 | class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): |
25 | def __init__(self, database: DatabasePool, db_conn, hs): | |
26 | self._account_data_id_gen = SlavedIdTracker( | |
27 | db_conn, | |
28 | "account_data", | |
29 | "stream_id", | |
30 | extra_tables=[ | |
31 | ("room_account_data", "stream_id"), | |
32 | ("room_tags_revisions", "stream_id"), | |
33 | ], | |
34 | ) | |
35 | ||
36 | super().__init__(database, db_conn, hs) | |
37 | ||
38 | def get_max_account_data_stream_id(self): | |
39 | return self._account_data_id_gen.get_current_token() | |
40 | ||
41 | def process_replication_rows(self, stream_name, instance_name, token, rows): | |
42 | if stream_name == TagAccountDataStream.NAME: | |
43 | self._account_data_id_gen.advance(instance_name, token) | |
44 | for row in rows: | |
45 | self.get_tags_for_user.invalidate((row.user_id,)) | |
46 | self._account_data_stream_cache.entity_has_changed(row.user_id, token) | |
47 | elif stream_name == AccountDataStream.NAME: | |
48 | self._account_data_id_gen.advance(instance_name, token) | |
49 | for row in rows: | |
50 | if not row.room_id: | |
51 | self.get_global_account_data_by_type_for_user.invalidate( | |
52 | (row.data_type, row.user_id) | |
53 | ) | |
54 | self.get_account_data_for_user.invalidate((row.user_id,)) | |
55 | self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) | |
56 | self.get_account_data_for_room_and_type.invalidate( | |
57 | (row.user_id, row.room_id, row.data_type) | |
58 | ) | |
59 | self._account_data_stream_cache.entity_has_changed(row.user_id, token) | |
60 | return super().process_replication_rows(stream_name, instance_name, token, rows) | |
22 | pass |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from synapse.replication.slave.storage._base import BaseSlavedStore |
16 | from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |
17 | from synapse.replication.tcp.streams import ToDeviceStream | |
18 | from synapse.storage.database import DatabasePool | |
19 | 16 | from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore |
20 | from synapse.util.caches.expiringcache import ExpiringCache | |
21 | from synapse.util.caches.stream_change_cache import StreamChangeCache | |
22 | 17 | |
23 | 18 | |
24 | 19 | class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): |
25 | def __init__(self, database: DatabasePool, db_conn, hs): | |
26 | super().__init__(database, db_conn, hs) | |
27 | self._device_inbox_id_gen = SlavedIdTracker( | |
28 | db_conn, "device_inbox", "stream_id" | |
29 | ) | |
30 | self._device_inbox_stream_cache = StreamChangeCache( | |
31 | "DeviceInboxStreamChangeCache", | |
32 | self._device_inbox_id_gen.get_current_token(), | |
33 | ) | |
34 | self._device_federation_outbox_stream_cache = StreamChangeCache( | |
35 | "DeviceFederationOutboxStreamChangeCache", | |
36 | self._device_inbox_id_gen.get_current_token(), | |
37 | ) | |
38 | ||
39 | self._last_device_delete_cache = ExpiringCache( | |
40 | cache_name="last_device_delete_cache", | |
41 | clock=self._clock, | |
42 | max_len=10000, | |
43 | expiry_ms=30 * 60 * 1000, | |
44 | ) | |
45 | ||
46 | def process_replication_rows(self, stream_name, instance_name, token, rows): | |
47 | if stream_name == ToDeviceStream.NAME: | |
48 | self._device_inbox_id_gen.advance(instance_name, token) | |
49 | for row in rows: | |
50 | if row.entity.startswith("@"): | |
51 | self._device_inbox_stream_cache.entity_has_changed( | |
52 | row.entity, token | |
53 | ) | |
54 | else: | |
55 | self._device_federation_outbox_stream_cache.entity_has_changed( | |
56 | row.entity, token | |
57 | ) | |
58 | return super().process_replication_rows(stream_name, instance_name, token, rows) | |
20 | pass |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | from synapse.replication.tcp.streams import ReceiptsStream | |
17 | from synapse.storage.database import DatabasePool | |
18 | 16 | from synapse.storage.databases.main.receipts import ReceiptsWorkerStore |
19 | 17 | |
20 | 18 | from ._base import BaseSlavedStore |
21 | from ._slaved_id_tracker import SlavedIdTracker | |
22 | 19 | |
23 | 20 | |
24 | 21 | class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): |
25 | def __init__(self, database: DatabasePool, db_conn, hs): | |
26 | # We instantiate this first as the ReceiptsWorkerStore constructor | |
27 | # needs to be able to call get_max_receipt_stream_id | |
28 | self._receipts_id_gen = SlavedIdTracker( | |
29 | db_conn, "receipts_linearized", "stream_id" | |
30 | ) | |
31 | ||
32 | super().__init__(database, db_conn, hs) | |
33 | ||
34 | def get_max_receipt_stream_id(self): | |
35 | return self._receipts_id_gen.get_current_token() | |
36 | ||
37 | def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): | |
38 | self.get_receipts_for_user.invalidate((user_id, receipt_type)) | |
39 | self._get_linearized_receipts_for_room.invalidate_many((room_id,)) | |
40 | self.get_last_receipt_event_id_for_user.invalidate( | |
41 | (user_id, room_id, receipt_type) | |
42 | ) | |
43 | self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) | |
44 | self.get_receipts_for_room.invalidate((room_id, receipt_type)) | |
45 | ||
46 | def process_replication_rows(self, stream_name, instance_name, token, rows): | |
47 | if stream_name == ReceiptsStream.NAME: | |
48 | self._receipts_id_gen.advance(instance_name, token) | |
49 | for row in rows: | |
50 | self.invalidate_caches_for_receipt( | |
51 | row.room_id, row.receipt_type, row.user_id | |
52 | ) | |
53 | self._receipts_stream_cache.entity_has_changed(row.room_id, token) | |
54 | ||
55 | return super().process_replication_rows(stream_name, instance_name, token, rows) | |
22 | pass |
50 | 50 | from synapse.replication.tcp.protocol import AbstractConnection |
51 | 51 | from synapse.replication.tcp.streams import ( |
52 | 52 | STREAMS_MAP, |
53 | AccountDataStream, | |
53 | 54 | BackfillStream, |
54 | 55 | CachesStream, |
55 | 56 | EventsStream, |
56 | 57 | FederationStream, |
58 | ReceiptsStream, | |
57 | 59 | Stream, |
60 | TagAccountDataStream, | |
61 | ToDeviceStream, | |
58 | 62 | TypingStream, |
59 | 63 | ) |
60 | 64 | |
114 | 118 | |
115 | 119 | continue |
116 | 120 | |
121 | if isinstance(stream, ToDeviceStream): | |
122 | # Only add ToDeviceStream as a source on instances in charge of | |
123 | # sending to-device messages. | 
124 | if hs.get_instance_name() in hs.config.worker.writers.to_device: | |
125 | self._streams_to_replicate.append(stream) | |
126 | ||
127 | continue | |
128 | ||
117 | 129 | if isinstance(stream, TypingStream): |
118 | 130 | # Only add TypingStream as a source on the instance in charge of |
119 | 131 | # typing. |
120 | 132 | if hs.config.worker.writers.typing == hs.get_instance_name(): |
133 | self._streams_to_replicate.append(stream) | |
134 | ||
135 | continue | |
136 | ||
137 | if isinstance(stream, (AccountDataStream, TagAccountDataStream)): | |
138 | # Only add AccountDataStream and TagAccountDataStream as a source on the | |
139 | # instance in charge of account_data persistence. | |
140 | if hs.get_instance_name() in hs.config.worker.writers.account_data: | |
141 | self._streams_to_replicate.append(stream) | |
142 | ||
143 | continue | |
144 | ||
145 | if isinstance(stream, ReceiptsStream): | |
146 | # Only add ReceiptsStream as a source on the instance in charge of | |
147 | # receipts. | |
148 | if hs.get_instance_name() in hs.config.worker.writers.receipts: | |
121 | 149 | self._streams_to_replicate.append(stream) |
122 | 150 | |
123 | 151 | continue |
0 | <html> | |
1 | <head> | |
2 | <title>Authentication Failed</title> | |
3 | </head> | |
4 | <body> | |
5 | <div> | |
6 | <p> | |
7 | We were unable to validate your <tt>{{server_name | e}}</tt> account via | |
8 | single-sign-on (SSO), because the SSO Identity Provider returned | |
9 | different details than when you logged in. | |
10 | </p> | |
11 | <p> | |
12 | Try the operation again, and ensure that you use the same details on | |
13 | the Identity Provider as when you log into your account. | |
14 | </p> | |
15 | </div> | |
16 | </body> | |
17 | </html> |
0 | <!DOCTYPE html> | |
1 | <html lang="en"> | |
2 | <head> | |
3 | <meta charset="UTF-8"> | |
4 | <link rel="stylesheet" href="/_matrix/static/client/login/style.css"> | |
5 | <title>{{server_name | e}} Login</title> | |
6 | </head> | |
7 | <body> | |
8 | <div id="container"> | |
9 | <h1 id="title">{{server_name | e}} Login</h1> | |
10 | <div class="login_flow"> | |
11 | <p>Choose one of the following identity providers:</p> | |
12 | <form> | |
13 | <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}"> | |
14 | <ul class="radiobuttons"> | |
15 | {% for p in providers %} | |
16 | <li> | |
17 | <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}"> | |
18 | <label for="prov{{loop.index}}">{{p.idp_name | e}}</label> | |
19 | {% if p.idp_icon %} | |
20 | <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/> | |
21 | {% endif %} | |
22 | </li> | |
23 | {% endfor %} | |
24 | </ul> | |
25 | <input type="submit" class="button button--full-width" id="button-submit" value="Submit"> | |
26 | </form> | |
27 | </div> | |
28 | </div> | |
29 | </body> | |
30 | </html> |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | 16 | import logging |
17 | from typing import TYPE_CHECKING, Tuple | |
18 | ||
19 | from twisted.web.http import Request | |
17 | 20 | |
18 | 21 | from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError |
19 | 22 | from synapse.http.servlet import RestServlet, parse_boolean, parse_integer |
22 | 25 | assert_requester_is_admin, |
23 | 26 | assert_user_is_admin, |
24 | 27 | ) |
28 | from synapse.types import JsonDict | |
29 | ||
30 | if TYPE_CHECKING: | |
31 | from synapse.app.homeserver import HomeServer | |
25 | 32 | |
26 | 33 | logger = logging.getLogger(__name__) |
27 | 34 | |
38 | 45 | admin_patterns("/quarantine_media/(?P<room_id>[^/]+)") |
39 | 46 | ) |
40 | 47 | |
41 | def __init__(self, hs): | |
42 | self.store = hs.get_datastore() | |
43 | self.auth = hs.get_auth() | |
44 | ||
45 | async def on_POST(self, request, room_id: str): | |
48 | def __init__(self, hs: "HomeServer"): | |
49 | self.store = hs.get_datastore() | |
50 | self.auth = hs.get_auth() | |
51 | ||
52 | async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: | |
46 | 53 | requester = await self.auth.get_user_by_req(request) |
47 | 54 | await assert_user_is_admin(self.auth, requester.user) |
48 | 55 | |
63 | 70 | |
64 | 71 | PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine") |
65 | 72 | |
66 | def __init__(self, hs): | |
67 | self.store = hs.get_datastore() | |
68 | self.auth = hs.get_auth() | |
69 | ||
70 | async def on_POST(self, request, user_id: str): | |
73 | def __init__(self, hs: "HomeServer"): | |
74 | self.store = hs.get_datastore() | |
75 | self.auth = hs.get_auth() | |
76 | ||
77 | async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: | |
71 | 78 | requester = await self.auth.get_user_by_req(request) |
72 | 79 | await assert_user_is_admin(self.auth, requester.user) |
73 | 80 | |
90 | 97 | "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" |
91 | 98 | ) |
92 | 99 | |
93 | def __init__(self, hs): | |
94 | self.store = hs.get_datastore() | |
95 | self.auth = hs.get_auth() | |
96 | ||
97 | async def on_POST(self, request, server_name: str, media_id: str): | |
100 | def __init__(self, hs: "HomeServer"): | |
101 | self.store = hs.get_datastore() | |
102 | self.auth = hs.get_auth() | |
103 | ||
104 | async def on_POST( | |
105 | self, request: Request, server_name: str, media_id: str | |
106 | ) -> Tuple[int, JsonDict]: | |
98 | 107 | requester = await self.auth.get_user_by_req(request) |
99 | 108 | await assert_user_is_admin(self.auth, requester.user) |
100 | 109 | |
108 | 117 | return 200, {} |
109 | 118 | |
110 | 119 | |
120 | class ProtectMediaByID(RestServlet): | |
121 | """Protect local media from being quarantined. | |
122 | """ | |
123 | ||
124 | PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") | |
125 | ||
126 | def __init__(self, hs: "HomeServer"): | |
127 | self.store = hs.get_datastore() | |
128 | self.auth = hs.get_auth() | |
129 | ||
130 | async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: | |
131 | requester = await self.auth.get_user_by_req(request) | |
132 | await assert_user_is_admin(self.auth, requester.user) | |
133 | ||
134 | logger.info("Protecting local media by ID: %s", media_id) | 
135 | | 
136 | # Mark this media ID as safe from quarantine | 
137 | await self.store.mark_local_media_as_safe(media_id) | |
138 | ||
139 | return 200, {} | |
140 | ||
141 | ||
111 | 142 | class ListMediaInRoom(RestServlet): |
112 | 143 | """Lists all of the media in a given room. |
113 | 144 | """ |
114 | 145 | |
115 | 146 | PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media") |
116 | 147 | |
117 | def __init__(self, hs): | |
118 | self.store = hs.get_datastore() | |
119 | self.auth = hs.get_auth() | |
120 | ||
121 | async def on_GET(self, request, room_id): | |
148 | def __init__(self, hs: "HomeServer"): | |
149 | self.store = hs.get_datastore() | |
150 | self.auth = hs.get_auth() | |
151 | ||
152 | async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: | |
122 | 153 | requester = await self.auth.get_user_by_req(request) |
123 | 154 | is_admin = await self.auth.is_server_admin(requester.user) |
124 | 155 | if not is_admin: |
132 | 163 | class PurgeMediaCacheRestServlet(RestServlet): |
133 | 164 | PATTERNS = admin_patterns("/purge_media_cache") |
134 | 165 | |
135 | def __init__(self, hs): | |
166 | def __init__(self, hs: "HomeServer"): | |
136 | 167 | self.media_repository = hs.get_media_repository() |
137 | 168 | self.auth = hs.get_auth() |
138 | 169 | |
139 | async def on_POST(self, request): | |
170 | async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: | |
140 | 171 | await assert_requester_is_admin(self.auth, request) |
141 | 172 | |
142 | 173 | before_ts = parse_integer(request, "before_ts", required=True) |
153 | 184 | |
154 | 185 | PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") |
155 | 186 | |
156 | def __init__(self, hs): | |
187 | def __init__(self, hs: "HomeServer"): | |
157 | 188 | self.store = hs.get_datastore() |
158 | 189 | self.auth = hs.get_auth() |
159 | 190 | self.server_name = hs.hostname |
160 | 191 | self.media_repository = hs.get_media_repository() |
161 | 192 | |
162 | async def on_DELETE(self, request, server_name: str, media_id: str): | |
193 | async def on_DELETE( | |
194 | self, request: Request, server_name: str, media_id: str | |
195 | ) -> Tuple[int, JsonDict]: | |
163 | 196 | await assert_requester_is_admin(self.auth, request) |
164 | 197 | |
165 | 198 | if self.server_name != server_name: |
181 | 214 | |
182 | 215 | PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete") |
183 | 216 | |
184 | def __init__(self, hs): | |
217 | def __init__(self, hs: "HomeServer"): | |
185 | 218 | self.store = hs.get_datastore() |
186 | 219 | self.auth = hs.get_auth() |
187 | 220 | self.server_name = hs.hostname |
188 | 221 | self.media_repository = hs.get_media_repository() |
189 | 222 | |
190 | async def on_POST(self, request, server_name: str): | |
223 | async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: | |
191 | 224 | await assert_requester_is_admin(self.auth, request) |
192 | 225 | |
193 | 226 | before_ts = parse_integer(request, "before_ts", required=True) |
221 | 254 | return 200, {"deleted_media": deleted_media, "total": total} |
222 | 255 | |
223 | 256 | |
224 | def register_servlets_for_media_repo(hs, http_server): | |
257 | def register_servlets_for_media_repo(hs: "HomeServer", http_server): | |
225 | 258 | """ |
226 | 259 | Media repo specific APIs. |
227 | 260 | """ |
229 | 262 | QuarantineMediaInRoom(hs).register(http_server) |
230 | 263 | QuarantineMediaByID(hs).register(http_server) |
231 | 264 | QuarantineMediaByUser(hs).register(http_server) |
265 | ProtectMediaByID(hs).register(http_server) | |
232 | 266 | ListMediaInRoom(hs).register(http_server) |
233 | 267 | DeleteMediaByID(hs).register(http_server) |
234 | 268 | DeleteMediaByDateSize(hs).register(http_server) |
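`ProtectMediaByID` is registered under the usual admin prefix, so assuming the standard `/_synapse/admin/v1` mount for `admin_patterns`, an admin can mark a local media ID as safe with one authenticated POST. A sketch using the `requests` library; host, token, and media ID are placeholders:

```python
import requests

resp = requests.post(
    "https://matrix.example.com/_synapse/admin/v1/media/protect/abcdef123456",
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={},
)
resp.raise_for_status()  # responds 200 with an empty JSON object on success
```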
243 | 243 | |
244 | 244 | if deactivate and not user["deactivated"]: |
245 | 245 | await self.deactivate_account_handler.deactivate_account( |
246 | target_user.to_string(), False | |
246 | target_user.to_string(), False, requester, by_admin=True | |
247 | 247 | ) |
248 | 248 | elif not deactivate and user["deactivated"]: |
249 | 249 | if "password" not in body: |
485 | 485 | class DeactivateAccountRestServlet(RestServlet): |
486 | 486 | PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") |
487 | 487 | |
488 | def __init__(self, hs): | |
488 | def __init__(self, hs: "HomeServer"): | |
489 | 489 | self._deactivate_account_handler = hs.get_deactivate_account_handler() |
490 | 490 | self.auth = hs.get_auth() |
491 | ||
492 | async def on_POST(self, request, target_user_id): | |
493 | await assert_requester_is_admin(self.auth, request) | |
491 | self.is_mine = hs.is_mine | |
492 | self.store = hs.get_datastore() | |
493 | ||
494 | async def on_POST(self, request: SynapseRequest, target_user_id: str) -> Tuple[int, JsonDict]: | 
495 | requester = await self.auth.get_user_by_req(request) | |
496 | await assert_user_is_admin(self.auth, requester.user) | |
497 | ||
498 | if not self.is_mine(UserID.from_string(target_user_id)): | |
499 | raise SynapseError(400, "Can only deactivate local users") | |
500 | ||
501 | if not await self.store.get_user_by_id(target_user_id): | |
502 | raise NotFoundError("User not found") | |
503 | ||
494 | 504 | body = parse_json_object_from_request(request, allow_empty_body=True) |
495 | 505 | erase = body.get("erase", False) |
496 | 506 | if not isinstance(erase, bool): |
500 | 510 | Codes.BAD_JSON, |
501 | 511 | ) |
502 | 512 | |
503 | UserID.from_string(target_user_id) | |
504 | ||
505 | 513 | result = await self._deactivate_account_handler.deactivate_account( |
506 | target_user_id, erase | |
514 | target_user_id, erase, requester, by_admin=True | |
507 | 515 | ) |
508 | 516 | if result: |
509 | 517 | id_server_unbind_result = "success" |
713 | 721 | async def on_GET(self, request, user_id): |
714 | 722 | await assert_requester_is_admin(self.auth, request) |
715 | 723 | |
716 | if not self.is_mine(UserID.from_string(user_id)): | |
717 | raise SynapseError(400, "Can only lookup local users") | |
718 | ||
719 | user = await self.store.get_user_by_id(user_id) | |
720 | if user is None: | |
721 | raise NotFoundError("Unknown user") | |
722 | ||
723 | 724 | room_ids = await self.store.get_rooms_for_user(user_id) |
724 | 725 | ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} |
725 | 726 | return 200, ret |
310 | 310 | return result |
311 | 311 | |
312 | 312 | |
313 | class BaseSSORedirectServlet(RestServlet): | |
314 | """Common base class for /login/sso/redirect impls""" | |
315 | ||
313 | class SsoRedirectServlet(RestServlet): | |
316 | 314 | PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) |
317 | 315 | |
316 | def __init__(self, hs: "HomeServer"): | |
317 | # make sure that the relevant handlers are instantiated, so that they | |
318 | # register themselves with the main SSOHandler. | |
319 | if hs.config.cas_enabled: | |
320 | hs.get_cas_handler() | |
321 | if hs.config.saml2_enabled: | |
322 | hs.get_saml_handler() | |
323 | if hs.config.oidc_enabled: | |
324 | hs.get_oidc_handler() | |
325 | self._sso_handler = hs.get_sso_handler() | |
326 | ||
318 | 327 | async def on_GET(self, request: SynapseRequest): |
319 | args = request.args | |
320 | if b"redirectUrl" not in args: | |
321 | return 400, "Redirect URL not specified for SSO auth" | |
322 | client_redirect_url = args[b"redirectUrl"][0] | |
323 | sso_url = await self.get_sso_url(request, client_redirect_url) | |
328 | client_redirect_url = parse_string( | |
329 | request, "redirectUrl", required=True, encoding=None | |
330 | ) | |
331 | sso_url = await self._sso_handler.handle_redirect_request( | |
332 | request, client_redirect_url | |
333 | ) | |
334 | logger.info("Redirecting to %s", sso_url) | |
324 | 335 | request.redirect(sso_url) |
325 | 336 | finish_request(request) |
326 | ||
327 | async def get_sso_url( | |
328 | self, request: SynapseRequest, client_redirect_url: bytes | |
329 | ) -> bytes: | |
330 | """Get the URL to redirect to, to perform SSO auth | |
331 | ||
332 | Args: | |
333 | request: The client request to redirect. | |
334 | client_redirect_url: the URL that we should redirect the | |
335 | client to when everything is done | |
336 | ||
337 | Returns: | |
338 | URL to redirect to | |
339 | """ | |
340 | # to be implemented by subclasses | |
341 | raise NotImplementedError() | |
342 | ||
343 | ||
344 | class CasRedirectServlet(BaseSSORedirectServlet): | |
345 | def __init__(self, hs): | |
346 | self._cas_handler = hs.get_cas_handler() | |
347 | ||
348 | async def get_sso_url( | |
349 | self, request: SynapseRequest, client_redirect_url: bytes | |
350 | ) -> bytes: | |
351 | return self._cas_handler.get_redirect_url( | |
352 | {"redirectUrl": client_redirect_url} | |
353 | ).encode("ascii") | |
354 | 337 | |
355 | 338 | |
356 | 339 | class CasTicketServlet(RestServlet): |
378 | 361 | ) |
379 | 362 | |
380 | 363 | |
381 | class SAMLRedirectServlet(BaseSSORedirectServlet): | |
382 | PATTERNS = client_patterns("/login/sso/redirect", v1=True) | |
383 | ||
384 | def __init__(self, hs): | |
385 | self._saml_handler = hs.get_saml_handler() | |
386 | ||
387 | async def get_sso_url( | |
388 | self, request: SynapseRequest, client_redirect_url: bytes | |
389 | ) -> bytes: | |
390 | return self._saml_handler.handle_redirect_request(client_redirect_url) | |
391 | ||
392 | ||
393 | class OIDCRedirectServlet(BaseSSORedirectServlet): | |
394 | """Implementation for /login/sso/redirect for the OIDC login flow.""" | |
395 | ||
396 | PATTERNS = client_patterns("/login/sso/redirect", v1=True) | |
397 | ||
398 | def __init__(self, hs): | |
399 | self._oidc_handler = hs.get_oidc_handler() | |
400 | ||
401 | async def get_sso_url( | |
402 | self, request: SynapseRequest, client_redirect_url: bytes | |
403 | ) -> bytes: | |
404 | return await self._oidc_handler.handle_redirect_request( | |
405 | request, client_redirect_url | |
406 | ) | |
407 | ||
408 | ||
409 | 364 | def register_servlets(hs, http_server): |
410 | 365 | LoginRestServlet(hs).register(http_server) |
366 | SsoRedirectServlet(hs).register(http_server) | |
411 | 367 | if hs.config.cas_enabled: |
412 | CasRedirectServlet(hs).register(http_server) | |
413 | 368 | CasTicketServlet(hs).register(http_server) |
414 | elif hs.config.saml2_enabled: | |
415 | SAMLRedirectServlet(hs).register(http_server) | |
416 | elif hs.config.oidc_enabled: | |
417 | OIDCRedirectServlet(hs).register(http_server) |
45 | 45 | from synapse.streams.config import PaginationConfig |
46 | 46 | from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID |
47 | 47 | from synapse.util import json_decoder |
48 | from synapse.util.stringutils import random_string | |
48 | from synapse.util.stringutils import parse_and_validate_server_name, random_string | |
49 | 49 | |
50 | 50 | if TYPE_CHECKING: |
51 | 51 | import synapse.server |
346 | 346 | # provided. |
347 | 347 | if server: |
348 | 348 | raise e |
349 | else: | |
350 | pass | |
351 | 349 | |
352 | 350 | limit = parse_integer(request, "limit", 0) |
353 | 351 | since_token = parse_string(request, "since", None) |
358 | 356 | |
359 | 357 | handler = self.hs.get_room_list_handler() |
360 | 358 | if server and server != self.hs.config.server_name: |
359 | # Ensure the server is valid. | |
360 | try: | |
361 | parse_and_validate_server_name(server) | |
362 | except ValueError: | |
363 | raise SynapseError( | |
364 | 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, | |
365 | ) | |
366 | ||
361 | 367 | try: |
362 | 368 | data = await handler.get_remote_public_room_list( |
363 | 369 | server, limit=limit, since_token=since_token |
401 | 407 | |
402 | 408 | handler = self.hs.get_room_list_handler() |
403 | 409 | if server and server != self.hs.config.server_name: |
410 | # Ensure the server is valid. | |
411 | try: | |
412 | parse_and_validate_server_name(server) | |
413 | except ValueError: | |
414 | raise SynapseError( | |
415 | 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, | |
416 | ) | |
417 | ||
404 | 418 | try: |
405 | 419 | data = await handler.get_remote_public_room_list( |
406 | 420 | server, |
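Both hunks add the same guard: before proxying a public-rooms request to a remote homeserver, the supplied name is passed through `parse_and_validate_server_name`, which raises `ValueError` on malformed input. In isolation (example names are illustrative):

```python
from synapse.util.stringutils import parse_and_validate_server_name

for candidate in ("matrix.org", "example.com:8448", "not a server name"):
    try:
        parse_and_validate_server_name(candidate)
        print(f"{candidate!r}: ok")
    except ValueError:
        print(f"{candidate!r}: rejected")
```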
19 | 19 | from typing import TYPE_CHECKING |
20 | 20 | from urllib.parse import urlparse |
21 | 21 | |
22 | if TYPE_CHECKING: | |
23 | from synapse.app.homeserver import HomeServer | |
24 | ||
25 | 22 | from synapse.api.constants import LoginType |
26 | 23 | from synapse.api.errors import ( |
27 | 24 | Codes, |
30 | 27 | ThreepidValidationError, |
31 | 28 | ) |
32 | 29 | from synapse.config.emailconfig import ThreepidBehaviour |
30 | from synapse.handlers.ui_auth import UIAuthSessionDataConstants | |
33 | 31 | from synapse.http.server import finish_request, respond_with_html |
34 | 32 | from synapse.http.servlet import ( |
35 | 33 | RestServlet, |
45 | 43 | |
46 | 44 | from ._base import client_patterns, interactive_auth_handler |
47 | 45 | |
46 | if TYPE_CHECKING: | |
47 | from synapse.app.homeserver import HomeServer | |
48 | ||
49 | ||
48 | 50 | logger = logging.getLogger(__name__) |
49 | 51 | |
50 | 52 | |
188 | 190 | requester = await self.auth.get_user_by_req(request) |
189 | 191 | try: |
190 | 192 | params, session_id = await self.auth_handler.validate_user_via_ui_auth( |
191 | requester, | |
193 | requester, request, body, "modify your account password", | |
194 | ) | |
195 | except InteractiveAuthIncompleteError as e: | |
196 | # The user needs to provide more steps to complete auth, but | |
197 | # they're not required to provide the password again. | |
198 | # | |
199 | # If a password is available now, hash the provided password and | |
200 | # store it for later. | |
201 | if new_password: | |
202 | password_hash = await self.auth_handler.hash(new_password) | |
203 | await self.auth_handler.set_session_data( | |
204 | e.session_id, | |
205 | UIAuthSessionDataConstants.PASSWORD_HASH, | |
206 | password_hash, | |
207 | ) | |
208 | raise | |
209 | user_id = requester.user.to_string() | |
210 | else: | |
211 | requester = None | |
212 | try: | |
213 | result, params, session_id = await self.auth_handler.check_ui_auth( | |
214 | [[LoginType.EMAIL_IDENTITY]], | |
192 | 215 | request, |
193 | 216 | body, |
194 | self.hs.get_ip_from_request(request), | |
195 | 217 | "modify your account password", |
196 | 218 | ) |
197 | 219 | except InteractiveAuthIncompleteError as e: |
203 | 225 | if new_password: |
204 | 226 | password_hash = await self.auth_handler.hash(new_password) |
205 | 227 | await self.auth_handler.set_session_data( |
206 | e.session_id, "password_hash", password_hash | |
207 | ) | |
208 | raise | |
209 | user_id = requester.user.to_string() | |
210 | else: | |
211 | requester = None | |
212 | try: | |
213 | result, params, session_id = await self.auth_handler.check_ui_auth( | |
214 | [[LoginType.EMAIL_IDENTITY]], | |
215 | request, | |
216 | body, | |
217 | self.hs.get_ip_from_request(request), | |
218 | "modify your account password", | |
219 | ) | |
220 | except InteractiveAuthIncompleteError as e: | |
221 | # The user needs to provide more steps to complete auth, but | |
222 | # they're not required to provide the password again. | |
223 | # | |
224 | # If a password is available now, hash the provided password and | |
225 | # store it for later. | |
226 | if new_password: | |
227 | password_hash = await self.auth_handler.hash(new_password) | |
228 | await self.auth_handler.set_session_data( | |
229 | e.session_id, "password_hash", password_hash | |
228 | e.session_id, | |
229 | UIAuthSessionDataConstants.PASSWORD_HASH, | |
230 | password_hash, | |
230 | 231 | ) |
231 | 232 | raise |
232 | 233 | |
259 | 260 | password_hash = await self.auth_handler.hash(new_password) |
260 | 261 | elif session_id is not None: |
261 | 262 | password_hash = await self.auth_handler.get_session_data( |
262 | session_id, "password_hash", None | |
263 | session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None | |
263 | 264 | ) |
264 | 265 | else: |
265 | 266 | # UI validation was skipped, but the request did not include a new |
303 | 304 | # allow ASes to deactivate their own users |
304 | 305 | if requester.app_service: |
305 | 306 | await self._deactivate_account_handler.deactivate_account( |
306 | requester.user.to_string(), erase | |
307 | requester.user.to_string(), erase, requester | |
307 | 308 | ) |
308 | 309 | return 200, {} |
309 | 310 | |
310 | 311 | await self.auth_handler.validate_user_via_ui_auth( |
312 | requester, request, body, "deactivate your account", | |
313 | ) | |
314 | result = await self._deactivate_account_handler.deactivate_account( | |
315 | requester.user.to_string(), | |
316 | erase, | |
311 | 317 | requester, |
312 | request, | |
313 | body, | |
314 | self.hs.get_ip_from_request(request), | |
315 | "deactivate your account", | |
316 | ) | |
317 | result = await self._deactivate_account_handler.deactivate_account( | |
318 | requester.user.to_string(), erase, id_server=body.get("id_server") | |
318 | id_server=body.get("id_server"), | |
319 | 319 | ) |
320 | 320 | if result: |
321 | 321 | id_server_unbind_result = "success" |
694 | 694 | assert_valid_client_secret(client_secret) |
695 | 695 | |
696 | 696 | await self.auth_handler.validate_user_via_ui_auth( |
697 | requester, | |
698 | request, | |
699 | body, | |
700 | self.hs.get_ip_from_request(request), | |
701 | "add a third-party identifier to your account", | |
697 | requester, request, body, "add a third-party identifier to your account", | |
702 | 698 | ) |
703 | 699 | |
704 | 700 | validation_session = await self.identity_handler.validate_threepid_session( |
36 | 36 | super().__init__() |
37 | 37 | self.auth = hs.get_auth() |
38 | 38 | self.store = hs.get_datastore() |
39 | self.notifier = hs.get_notifier() | |
40 | self._is_worker = hs.config.worker_app is not None | |
39 | self.handler = hs.get_account_data_handler() | |
41 | 40 | |
42 | 41 | async def on_PUT(self, request, user_id, account_data_type): |
43 | if self._is_worker: | |
44 | raise Exception("Cannot handle PUT /account_data on worker") | |
45 | ||
46 | 42 | requester = await self.auth.get_user_by_req(request) |
47 | 43 | if user_id != requester.user.to_string(): |
48 | 44 | raise AuthError(403, "Cannot add account data for other users.") |
49 | 45 | |
50 | 46 | body = parse_json_object_from_request(request) |
51 | 47 | |
52 | max_id = await self.store.add_account_data_for_user( | |
53 | user_id, account_data_type, body | |
54 | ) | |
55 | ||
56 | self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) | |
48 | await self.handler.add_account_data_for_user(user_id, account_data_type, body) | |
57 | 49 | |
58 | 50 | return 200, {} |
59 | 51 | |
88 | 80 | super().__init__() |
89 | 81 | self.auth = hs.get_auth() |
90 | 82 | self.store = hs.get_datastore() |
91 | self.notifier = hs.get_notifier() | |
92 | self._is_worker = hs.config.worker_app is not None | |
83 | self.handler = hs.get_account_data_handler() | |
93 | 84 | |
94 | 85 | async def on_PUT(self, request, user_id, room_id, account_data_type): |
95 | if self._is_worker: | |
96 | raise Exception("Cannot handle PUT /account_data on worker") | |
97 | ||
98 | 86 | requester = await self.auth.get_user_by_req(request) |
99 | 87 | if user_id != requester.user.to_string(): |
100 | 88 | raise AuthError(403, "Cannot add account data for other users.") |
108 | 96 | " Use /rooms/!roomId:server.name/read_markers", |
109 | 97 | ) |
110 | 98 | |
111 | max_id = await self.store.add_account_data_to_room( | |
99 | await self.handler.add_account_data_to_room( | |
112 | 100 | user_id, room_id, account_data_type, body |
113 | 101 | ) |
114 | ||
115 | self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) | |
116 | 102 | |
117 | 103 | return 200, {} |
118 | 104 |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import logging |
16 | from typing import TYPE_CHECKING | |
16 | 17 | |
17 | 18 | from synapse.api.constants import LoginType |
18 | 19 | from synapse.api.errors import SynapseError |
21 | 22 | from synapse.http.servlet import RestServlet, parse_string |
22 | 23 | |
23 | 24 | from ._base import client_patterns |
25 | ||
26 | if TYPE_CHECKING: | |
27 | from synapse.server import HomeServer | |
24 | 28 | |
25 | 29 | logger = logging.getLogger(__name__) |
26 | 30 | |
34 | 38 | |
35 | 39 | PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") |
36 | 40 | |
37 | def __init__(self, hs): | |
41 | def __init__(self, hs: "HomeServer"): | |
38 | 42 | super().__init__() |
39 | 43 | self.hs = hs |
40 | 44 | self.auth = hs.get_auth() |
41 | 45 | self.auth_handler = hs.get_auth_handler() |
42 | 46 | self.registration_handler = hs.get_registration_handler() |
43 | ||
44 | # SSO configuration. | |
45 | self._cas_enabled = hs.config.cas_enabled | |
46 | if self._cas_enabled: | |
47 | self._cas_handler = hs.get_cas_handler() | |
48 | self._cas_server_url = hs.config.cas_server_url | |
49 | self._cas_service_url = hs.config.cas_service_url | |
50 | self._saml_enabled = hs.config.saml2_enabled | |
51 | if self._saml_enabled: | |
52 | self._saml_handler = hs.get_saml_handler() | |
53 | self._oidc_enabled = hs.config.oidc_enabled | |
54 | if self._oidc_enabled: | |
55 | self._oidc_handler = hs.get_oidc_handler() | |
56 | self._cas_server_url = hs.config.cas_server_url | |
57 | self._cas_service_url = hs.config.cas_service_url | |
58 | ||
59 | 47 | self.recaptcha_template = hs.config.recaptcha_template |
60 | 48 | self.terms_template = hs.config.terms_template |
61 | 49 | self.success_template = hs.config.fallback_success_template |
84 | 72 | elif stagetype == LoginType.SSO: |
85 | 73 | # Display a confirmation page which prompts the user to |
86 | 74 | # re-authenticate with their SSO provider. |
87 | if self._cas_enabled: | |
88 | # Generate a request to CAS that redirects back to an endpoint | |
89 | # to verify the successful authentication. | |
90 | sso_redirect_url = self._cas_handler.get_redirect_url( | |
91 | {"session": session}, | |
92 | ) | |
93 | ||
94 | elif self._saml_enabled: | |
95 | # Some SAML identity providers (e.g. Google) require a | |
96 | # RelayState parameter on requests. It is not necessary here, so | |
97 | # pass in a dummy redirect URL (which will never get used). | |
98 | client_redirect_url = b"unused" | |
99 | sso_redirect_url = self._saml_handler.handle_redirect_request( | |
100 | client_redirect_url, session | |
101 | ) | |
102 | ||
103 | elif self._oidc_enabled: | |
104 | client_redirect_url = b"" | |
105 | sso_redirect_url = await self._oidc_handler.handle_redirect_request( | |
106 | request, client_redirect_url, session | |
107 | ) | |
108 | ||
109 | else: | |
110 | raise SynapseError(400, "Homeserver not configured for SSO.") | |
111 | ||
112 | html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) | |
75 | html = await self.auth_handler.start_sso_ui_auth(request, session) | |
113 | 76 | |
114 | 77 | else: |
115 | 78 | raise SynapseError(404, "Unknown auth stage type") |
133 | 96 | authdict = {"response": response, "session": session} |
134 | 97 | |
135 | 98 | success = await self.auth_handler.add_oob_auth( |
136 | LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) | |
99 | LoginType.RECAPTCHA, authdict, request.getClientIP() | |
137 | 100 | ) |
138 | 101 | |
139 | 102 | if success: |
149 | 112 | authdict = {"session": session} |
150 | 113 | |
151 | 114 | success = await self.auth_handler.add_oob_auth( |
152 | LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) | |
115 | LoginType.TERMS, authdict, request.getClientIP() | |
153 | 116 | ) |
154 | 117 | |
155 | 118 | if success: |
82 | 82 | assert_params_in_dict(body, ["devices"]) |
83 | 83 | |
84 | 84 | await self.auth_handler.validate_user_via_ui_auth( |
85 | requester, | |
86 | request, | |
87 | body, | |
88 | self.hs.get_ip_from_request(request), | |
89 | "remove device(s) from your account", | |
85 | requester, request, body, "remove device(s) from your account", | |
90 | 86 | ) |
91 | 87 | |
92 | 88 | await self.device_handler.delete_devices( |
132 | 128 | raise |
133 | 129 | |
134 | 130 | await self.auth_handler.validate_user_via_ui_auth( |
135 | requester, | |
136 | request, | |
137 | body, | |
138 | self.hs.get_ip_from_request(request), | |
139 | "remove a device from your account", | |
131 | requester, request, body, "remove a device from your account", | |
140 | 132 | ) |
141 | 133 | |
142 | 134 | await self.device_handler.delete_device(requester.user.to_string(), device_id) |
270 | 270 | body = parse_json_object_from_request(request) |
271 | 271 | |
272 | 272 | await self.auth_handler.validate_user_via_ui_auth( |
273 | requester, | |
274 | request, | |
275 | body, | |
276 | self.hs.get_ip_from_request(request), | |
277 | "add a device signing key to your account", | |
273 | requester, request, body, "add a device signing key to your account", | |
278 | 274 | ) |
279 | 275 | |
280 | 276 | result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) |
37 | 37 | from synapse.config.registration import RegistrationConfig |
38 | 38 | from synapse.config.server import is_threepid_reserved |
39 | 39 | from synapse.handlers.auth import AuthHandler |
40 | from synapse.handlers.ui_auth import UIAuthSessionDataConstants | |
40 | 41 | from synapse.http.server import finish_request, respond_with_html |
41 | 42 | from synapse.http.servlet import ( |
42 | 43 | RestServlet, |
352 | 353 | 403, "Registration has been disabled", errcode=Codes.FORBIDDEN |
353 | 354 | ) |
354 | 355 | |
355 | ip = self.hs.get_ip_from_request(request) | |
356 | ip = request.getClientIP() | |
356 | 357 | with self.ratelimiter.ratelimit(ip) as wait_deferred: |
357 | 358 | await wait_deferred |
358 | 359 | |
493 | 494 | # user here. We carry on and go through the auth checks though, |
494 | 495 | # for paranoia. |
495 | 496 | registered_user_id = await self.auth_handler.get_session_data( |
496 | session_id, "registered_user_id", None | |
497 | session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None | |
497 | 498 | ) |
498 | 499 | # Extract the previously-hashed password from the session. |
499 | 500 | password_hash = await self.auth_handler.get_session_data( |
500 | session_id, "password_hash", None | |
501 | session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None | |
501 | 502 | ) |
502 | 503 | |
503 | 504 | # Ensure that the username is valid. |
512 | 513 | # not this will raise a user-interactive auth error. |
513 | 514 | try: |
514 | 515 | auth_result, params, session_id = await self.auth_handler.check_ui_auth( |
515 | self._registration_flows, | |
516 | request, | |
517 | body, | |
518 | self.hs.get_ip_from_request(request), | |
519 | "register a new account", | |
516 | self._registration_flows, request, body, "register a new account", | |
520 | 517 | ) |
521 | 518 | except InteractiveAuthIncompleteError as e: |
522 | 519 | # The user needs to provide more steps to complete auth. |
531 | 528 | if not password_hash and password: |
532 | 529 | password_hash = await self.auth_handler.hash(password) |
533 | 530 | await self.auth_handler.set_session_data( |
534 | e.session_id, "password_hash", password_hash | |
531 | e.session_id, | |
532 | UIAuthSessionDataConstants.PASSWORD_HASH, | |
533 | password_hash, | |
535 | 534 | ) |
536 | 535 | raise |
537 | 536 | |
632 | 631 | # Remember that the user account has been registered (and the user |
633 | 632 | # ID it was registered with, since it might not have been specified). |
634 | 633 | await self.auth_handler.set_session_data( |
635 | session_id, "registered_user_id", registered_user_id | |
634 | session_id, | |
635 | UIAuthSessionDataConstants.REGISTERED_USER_ID, | |
636 | registered_user_id, | |
636 | 637 | ) |
637 | 638 | |
638 | 639 | registered = True |
57 | 57 | def __init__(self, hs): |
58 | 58 | super().__init__() |
59 | 59 | self.auth = hs.get_auth() |
60 | self.store = hs.get_datastore() | |
61 | self.notifier = hs.get_notifier() | |
60 | self.handler = hs.get_account_data_handler() | |
62 | 61 | |
63 | 62 | async def on_PUT(self, request, user_id, room_id, tag): |
64 | 63 | requester = await self.auth.get_user_by_req(request) |
67 | 66 | |
68 | 67 | body = parse_json_object_from_request(request) |
69 | 68 | |
70 | max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) | |
71 | ||
72 | self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) | |
69 | await self.handler.add_tag_to_room(user_id, room_id, tag, body) | |
73 | 70 | |
74 | 71 | return 200, {} |
75 | 72 | |
78 | 75 | if user_id != requester.user.to_string(): |
79 | 76 | raise AuthError(403, "Cannot add tags for other users.") |
80 | 77 | |
81 | max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) | |
82 | ||
83 | self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) | |
78 | await self.handler.remove_tag_from_room(user_id, room_id, tag) | |
84 | 79 | |
85 | 80 | return 200, {} |
86 | 81 |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2019 New Vector Ltd | |
2 | # Copyright 2019-2021 The Matrix.org Foundation C.I.C. | |
3 | 3 | # |
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | # you may not use this file except in compliance with the License. |
16 | 16 | import logging |
17 | 17 | import os |
18 | 18 | import urllib |
19 | from typing import Awaitable | |
19 | from typing import Awaitable, Dict, Generator, List, Optional, Tuple | |
20 | 20 | |
21 | 21 | from twisted.internet.interfaces import IConsumer |
22 | 22 | from twisted.protocols.basic import FileSender |
23 | from twisted.web.http import Request | |
23 | 24 | |
24 | 25 | from synapse.api.errors import Codes, SynapseError, cs_error |
25 | 26 | from synapse.http.server import finish_request, respond_with_json |
45 | 46 | ] |
46 | 47 | |
47 | 48 | |
48 | def parse_media_id(request): | |
49 | def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: | |
49 | 50 | try: |
50 | 51 | # This allows users to append e.g. /test.png to the URL. Useful for |
51 | 52 | # clients that parse the URL to see content type. |
68 | 69 | ) |
69 | 70 | |
70 | 71 | |
71 | def respond_404(request): | |
72 | def respond_404(request: Request) -> None: | |
72 | 73 | respond_with_json( |
73 | 74 | request, |
74 | 75 | 404, |
78 | 79 | |
79 | 80 | |
80 | 81 | async def respond_with_file( |
81 | request, media_type, file_path, file_size=None, upload_name=None | |
82 | ): | |
82 | request: Request, | |
83 | media_type: str, | |
84 | file_path: str, | |
85 | file_size: Optional[int] = None, | |
86 | upload_name: Optional[str] = None, | |
87 | ) -> None: | |
83 | 88 | logger.debug("Responding with %r", file_path) |
84 | 89 | |
85 | 90 | if os.path.isfile(file_path): |
97 | 102 | respond_404(request) |
98 | 103 | |
99 | 104 | |
100 | def add_file_headers(request, media_type, file_size, upload_name): | |
105 | def add_file_headers( | |
106 | request: Request, | |
107 | media_type: str, | |
108 | file_size: Optional[int], | |
109 | upload_name: Optional[str], | |
110 | ) -> None: | |
101 | 111 | """Adds the correct response headers in preparation for responding with the |
102 | 112 | media. |
103 | 113 | |
104 | 114 | Args: |
105 | request (twisted.web.http.Request) | |
106 | media_type (str): The media/content type. | |
107 | file_size (int): Size in bytes of the media, if known. | |
108 | upload_name (str): The name of the requested file, if any. | |
115 | request | |
116 | media_type: The media/content type. | |
117 | file_size: Size in bytes of the media, if known. | |
118 | upload_name: The name of the requested file, if any. | |
109 | 119 | """ |
110 | 120 | |
111 | 121 | def _quote(x): |
152 | 162 | # select private. don't bother setting Expires as all our |
153 | 163 | # clients are smart enough to be happy with Cache-Control |
154 | 164 | request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") |
155 | request.setHeader(b"Content-Length", b"%d" % (file_size,)) | |
165 | if file_size is not None: | |
166 | request.setHeader(b"Content-Length", b"%d" % (file_size,)) | |
156 | 167 | |
157 | 168 | # Tell web crawlers to not index, archive, or follow links in media. This |
158 | 169 | # should help to prevent things in the media repo from showing up in web |
183 | 194 | } |
184 | 195 | |
185 | 196 | |
186 | def _can_encode_filename_as_token(x): | |
197 | def _can_encode_filename_as_token(x: str) -> bool: | |
187 | 198 | for c in x: |
188 | 199 | # from RFC2616: |
189 | 200 | # |
205 | 216 | |
206 | 217 | |
207 | 218 | async def respond_with_responder( |
208 | request, responder, media_type, file_size, upload_name=None | |
209 | ): | |
219 | request: Request, | |
220 | responder: "Optional[Responder]", | |
221 | media_type: str, | |
222 | file_size: Optional[int], | |
223 | upload_name: Optional[str] = None, | |
224 | ) -> None: | |
210 | 225 | """Responds to the request with given responder. If responder is None then |
211 | 226 | returns 404. |
212 | 227 | |
213 | 228 | Args: |
214 | request (twisted.web.http.Request) | |
215 | responder (Responder|None) | |
216 | media_type (str): The media/content type. | |
217 | file_size (int|None): Size in bytes of the media. If not known it should be None | |
218 | upload_name (str|None): The name of the requested file, if any. | |
229 | request | |
230 | responder | |
231 | media_type: The media/content type. | |
232 | file_size: Size in bytes of the media. If not known it should be None | |
233 | upload_name: The name of the requested file, if any. | |
219 | 234 | """ |
220 | 235 | if request._disconnected: |
221 | 236 | logger.warning( |
307 | 322 | self.thumbnail_type = thumbnail_type |
308 | 323 | |
309 | 324 | |
310 | def get_filename_from_headers(headers): | |
325 | def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: | |
311 | 326 | """ |
312 | 327 | Get the filename of the downloaded file by inspecting the |
313 | 328 | Content-Disposition HTTP header. |
314 | 329 | |
315 | 330 | Args: |
316 | headers (dict[bytes, list[bytes]]): The HTTP request headers. | |
331 | headers: The HTTP request headers. | |
317 | 332 | |
318 | 333 | Returns: |
319 | A Unicode string of the filename, or None. | |
334 | The filename, or None. | |
320 | 335 | """ |
321 | 336 | content_disposition = headers.get(b"Content-Disposition", [b""]) |
322 | 337 | |
323 | 338 | # No header, bail out. |
324 | 339 | if not content_disposition[0]: |
325 | return | |
340 | return None | |
326 | 341 | |
327 | 342 | _, params = _parse_header(content_disposition[0]) |
328 | 343 | |
355 | 370 | return upload_name |
356 | 371 | |
357 | 372 | |
358 | def _parse_header(line): | |
373 | def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: | |
359 | 374 | """Parse a Content-type like header. |
360 | 375 | |
361 | 376 | Cargo-culted from `cgi`, but works on bytes rather than strings. |
362 | 377 | |
363 | 378 | Args: |
364 | line (bytes): header to be parsed | |
379 | line: header to be parsed | |
365 | 380 | |
366 | 381 | Returns: |
367 | Tuple[bytes, dict[bytes, bytes]]: | |
368 | the main content-type, followed by the parameter dictionary | |
382 | The main content-type, followed by the parameter dictionary | |
369 | 383 | """ |
370 | 384 | parts = _parseparam(b";" + line) |
371 | 385 | key = next(parts) |
385 | 399 | return key, pdict |
386 | 400 | |
387 | 401 | |
388 | def _parseparam(s): | |
402 | def _parseparam(s: bytes) -> Generator[bytes, None, None]: | |
389 | 403 | """Generator which splits the input on ;, respecting double-quoted sequences |
390 | 404 | |
391 | 405 | Cargo-culted from `cgi`, but works on bytes rather than strings. |
392 | 406 | |
393 | 407 | Args: |
394 | s (bytes): header to be parsed | |
408 | s: header to be parsed | |
395 | 409 | |
396 | 410 | Returns: |
397 | Iterable[bytes]: the split input | |
411 | The split input | |
398 | 412 | """ |
399 | 413 | while s[:1] == b";": |
400 | 414 | s = s[1:] |
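With `_parse_header` and `_parseparam` now explicitly typed over bytes, a short behavioural sketch may help (the header value is an invented example; `_parse_header` is the function from the hunk above):

```python
# Mirrors cgi.parse_header, but operates on bytes throughout.
main_value, params = _parse_header(b'inline; filename="cat photo.png"')
assert main_value == b"inline"
assert params == {b"filename": b"cat photo.png"}  # surrounding quotes stripped
```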
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2018 Will Hunt <will@half-shot.uk> |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
13 | 14 | # limitations under the License. |
14 | 15 | # |
15 | 16 | |
17 | from typing import TYPE_CHECKING | |
18 | ||
19 | from twisted.web.http import Request | |
20 | ||
16 | 21 | from synapse.http.server import DirectServeJsonResource, respond_with_json |
22 | ||
23 | if TYPE_CHECKING: | |
24 | from synapse.app.homeserver import HomeServer | |
17 | 25 | |
18 | 26 | |
19 | 27 | class MediaConfigResource(DirectServeJsonResource): |
20 | 28 | isLeaf = True |
21 | 29 | |
22 | def __init__(self, hs): | |
30 | def __init__(self, hs: "HomeServer"): | |
23 | 31 | super().__init__() |
24 | 32 | config = hs.get_config() |
25 | 33 | self.clock = hs.get_clock() |
26 | 34 | self.auth = hs.get_auth() |
27 | 35 | self.limits_dict = {"m.upload.size": config.max_upload_size} |
28 | 36 | |
29 | async def _async_render_GET(self, request): | |
37 | async def _async_render_GET(self, request: Request) -> None: | |
30 | 38 | await self.auth.get_user_by_req(request) |
31 | 39 | respond_with_json(request, 200, self.limits_dict, send_cors=True) |
32 | 40 | |
33 | async def _async_render_OPTIONS(self, request): | |
41 | async def _async_render_OPTIONS(self, request: Request) -> None: | |
34 | 42 | respond_with_json(request, 200, {}, send_cors=True) |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
12 | 13 | # See the License for the specific language governing permissions and |
13 | 14 | # limitations under the License. |
14 | 15 | import logging |
16 | from typing import TYPE_CHECKING | |
15 | 17 | |
16 | import synapse.http.servlet | |
18 | from twisted.web.http import Request | |
19 | ||
17 | 20 | from synapse.http.server import DirectServeJsonResource, set_cors_headers |
21 | from synapse.http.servlet import parse_boolean | |
18 | 22 | |
19 | 23 | from ._base import parse_media_id, respond_404 |
24 | ||
25 | if TYPE_CHECKING: | |
26 | from synapse.app.homeserver import HomeServer | |
27 | from synapse.rest.media.v1.media_repository import MediaRepository | |
20 | 28 | |
21 | 29 | logger = logging.getLogger(__name__) |
22 | 30 | |
24 | 32 | class DownloadResource(DirectServeJsonResource): |
25 | 33 | isLeaf = True |
26 | 34 | |
27 | def __init__(self, hs, media_repo): | |
35 | def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): | |
28 | 36 | super().__init__() |
29 | 37 | self.media_repo = media_repo |
30 | 38 | self.server_name = hs.hostname |
31 | 39 | |
32 | async def _async_render_GET(self, request): | |
40 | async def _async_render_GET(self, request: Request) -> None: | |
33 | 41 | set_cors_headers(request) |
34 | 42 | request.setHeader( |
35 | 43 | b"Content-Security-Policy", |
48 | 56 | if server_name == self.server_name: |
49 | 57 | await self.media_repo.get_local_media(request, media_id, name) |
50 | 58 | else: |
51 | allow_remote = synapse.http.servlet.parse_boolean( | |
52 | request, "allow_remote", default=True | |
53 | ) | |
59 | allow_remote = parse_boolean(request, "allow_remote", default=True) | |
54 | 60 | if not allow_remote: |
55 | 61 | logger.info( |
56 | 62 | "Rejecting request for remote media %s/%s due to allow_remote", |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
15 | 16 | import functools |
16 | 17 | import os |
17 | 18 | import re |
19 | from typing import Callable, List | |
18 | 20 | |
19 | 21 | NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") |
20 | 22 | |
21 | 23 | |
22 | def _wrap_in_base_path(func): | |
24 | def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": | |
23 | 25 | """Takes a function that returns a relative path and turns it into an |
24 | 26 | absolute path based on the location of the primary media store |
25 | 27 | """ |
40 | 42 | to write to the backup media store (when one is configured) |
41 | 43 | """ |
42 | 44 | |
43 | def __init__(self, primary_base_path): | |
45 | def __init__(self, primary_base_path: str): | |
44 | 46 | self.base_path = primary_base_path |
45 | 47 | |
46 | 48 | def default_thumbnail_rel( |
47 | self, default_top_level, default_sub_type, width, height, content_type, method | |
48 | ): | |
49 | self, | |
50 | default_top_level: str, | |
51 | default_sub_type: str, | |
52 | width: int, | |
53 | height: int, | |
54 | content_type: str, | |
55 | method: str, | |
56 | ) -> str: | |
49 | 57 | top_level_type, sub_type = content_type.split("/") |
50 | 58 | file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) |
51 | 59 | return os.path.join( |
54 | 62 | |
55 | 63 | default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) |
56 | 64 | |
57 | def local_media_filepath_rel(self, media_id): | |
65 | def local_media_filepath_rel(self, media_id: str) -> str: | |
58 | 66 | return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) |
59 | 67 | |
60 | 68 | local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) |
61 | 69 | |
62 | def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): | |
70 | def local_media_thumbnail_rel( | |
71 | self, media_id: str, width: int, height: int, content_type: str, method: str | |
72 | ) -> str: | |
63 | 73 | top_level_type, sub_type = content_type.split("/") |
64 | 74 | file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) |
65 | 75 | return os.path.join( |
85 | 95 | media_id[4:], |
86 | 96 | ) |
87 | 97 | |
88 | def remote_media_filepath_rel(self, server_name, file_id): | |
98 | def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: | |
89 | 99 | return os.path.join( |
90 | 100 | "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] |
91 | 101 | ) |
93 | 103 | remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) |
94 | 104 | |
95 | 105 | def remote_media_thumbnail_rel( |
96 | self, server_name, file_id, width, height, content_type, method | |
97 | ): | |
106 | self, | |
107 | server_name: str, | |
108 | file_id: str, | |
109 | width: int, | |
110 | height: int, | |
111 | content_type: str, | |
112 | method: str, | |
113 | ) -> str: | |
98 | 114 | top_level_type, sub_type = content_type.split("/") |
99 | 115 | file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) |
100 | 116 | return os.path.join( |
112 | 128 | # Should be removed after some time, when most of the thumbnails are stored |
113 | 129 | # using the new path. |
114 | 130 | def remote_media_thumbnail_rel_legacy( |
115 | self, server_name, file_id, width, height, content_type | |
131 | self, server_name: str, file_id: str, width: int, height: int, content_type: str | |
116 | 132 | ): |
117 | 133 | top_level_type, sub_type = content_type.split("/") |
118 | 134 | file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) |
125 | 141 | file_name, |
126 | 142 | ) |
127 | 143 | |
128 | def remote_media_thumbnail_dir(self, server_name, file_id): | |
144 | def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: | |
129 | 145 | return os.path.join( |
130 | 146 | self.base_path, |
131 | 147 | "remote_thumbnail", |
135 | 151 | file_id[4:], |
136 | 152 | ) |
137 | 153 | |
138 | def url_cache_filepath_rel(self, media_id): | |
154 | def url_cache_filepath_rel(self, media_id: str) -> str: | |
139 | 155 | if NEW_FORMAT_ID_RE.match(media_id): |
140 | 156 | # Media id is of the form <DATE><RANDOM_STRING> |
141 | 157 | # E.g.: 2017-09-28-fsdRDt24DS234dsf |
145 | 161 | |
146 | 162 | url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) |
147 | 163 | |
148 | def url_cache_filepath_dirs_to_delete(self, media_id): | |
164 | def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: | |
149 | 165 | "The dirs to try and remove if we delete the media_id file" |
150 | 166 | if NEW_FORMAT_ID_RE.match(media_id): |
151 | 167 | return [os.path.join(self.base_path, "url_cache", media_id[:10])] |
155 | 171 | os.path.join(self.base_path, "url_cache", media_id[0:2]), |
156 | 172 | ] |
157 | 173 | |
158 | def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): | |
174 | def url_cache_thumbnail_rel( | |
175 | self, media_id: str, width: int, height: int, content_type: str, method: str | |
176 | ) -> str: | |
159 | 177 | # Media id is of the form <DATE><RANDOM_STRING> |
160 | 178 | # E.g.: 2017-09-28-fsdRDt24DS234dsf |
161 | 179 | |
177 | 195 | |
178 | 196 | url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) |
179 | 197 | |
180 | def url_cache_thumbnail_directory(self, media_id): | |
198 | def url_cache_thumbnail_directory(self, media_id: str) -> str: | |
181 | 199 | # Media id is of the form <DATE><RANDOM_STRING> |
182 | 200 | # E.g.: 2017-09-28-fsdRDt24DS234dsf |
183 | 201 | |
194 | 212 | media_id[4:], |
195 | 213 | ) |
196 | 214 | |
197 | def url_cache_thumbnail_dirs_to_delete(self, media_id): | |
215 | def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: | |
198 | 216 | "The dirs to try and remove if we delete the media_id thumbnails" |
199 | 217 | # Media id is of the form <DATE><RANDOM_STRING> |
200 | 218 | # E.g.: 2017-09-28-fsdRDt24DS234dsf |
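All of the newly typed `MediaFilePaths` helpers shard a media ID into two-character directory levels, as the `media_id[0:2]/media_id[2:4]/media_id[4:]` joins above show. A hedged illustration (base path and media ID invented):

```python
paths = MediaFilePaths("/var/lib/synapse/media_store")  # class from the hunk above

paths.local_media_filepath("GerZNDnDZVjsOtar")
# -> "/var/lib/synapse/media_store/local_content/Ge/rZ/NDnDZVjsOtar"
```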
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2018 New Vector Ltd | |
2 | # Copyright 2018-2021 The Matrix.org Foundation C.I.C. | |
3 | 3 | # |
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | # you may not use this file except in compliance with the License. |
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | ||
16 | 15 | import errno |
17 | 16 | import logging |
18 | 17 | import os |
19 | 18 | import shutil |
20 | from typing import IO, Dict, List, Optional, Tuple | |
19 | from io import BytesIO | |
20 | from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple | |
21 | 21 | |
22 | 22 | import twisted.internet.error |
23 | 23 | import twisted.web.http |
55 | 55 | from .thumbnailer import Thumbnailer, ThumbnailError |
56 | 56 | from .upload_resource import UploadResource |
57 | 57 | |
58 | if TYPE_CHECKING: | |
59 | from synapse.app.homeserver import HomeServer | |
60 | ||
58 | 61 | logger = logging.getLogger(__name__) |
59 | 62 | |
60 | 63 | |
62 | 65 | |
63 | 66 | |
64 | 67 | class MediaRepository: |
65 | def __init__(self, hs): | |
68 | def __init__(self, hs: "HomeServer"): | |
66 | 69 | self.hs = hs |
67 | 70 | self.auth = hs.get_auth() |
68 | 71 | self.client = hs.get_federation_http_client() |
72 | 75 | self.max_upload_size = hs.config.max_upload_size |
73 | 76 | self.max_image_pixels = hs.config.max_image_pixels |
74 | 77 | |
75 | self.primary_base_path = hs.config.media_store_path | |
76 | self.filepaths = MediaFilePaths(self.primary_base_path) | |
78 | self.primary_base_path = hs.config.media_store_path # type: str | |
79 | self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths | |
77 | 80 | |
78 | 81 | self.dynamic_thumbnails = hs.config.dynamic_thumbnails |
79 | 82 | self.thumbnail_requirements = hs.config.thumbnail_requirements |
80 | 83 | |
81 | 84 | self.remote_media_linearizer = Linearizer(name="media_remote") |
82 | 85 | |
83 | self.recently_accessed_remotes = set() | |
84 | self.recently_accessed_locals = set() | |
86 | self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] | |
87 | self.recently_accessed_locals = set() # type: Set[str] | |
85 | 88 | |
86 | 89 | self.federation_domain_whitelist = hs.config.federation_domain_whitelist |
87 | 90 | |
112 | 115 | "update_recently_accessed_media", self._update_recently_accessed |
113 | 116 | ) |
114 | 117 | |
115 | async def _update_recently_accessed(self): | |
118 | async def _update_recently_accessed(self) -> None: | |
116 | 119 | remote_media = self.recently_accessed_remotes |
117 | 120 | self.recently_accessed_remotes = set() |
118 | 121 | |
123 | 126 | local_media, remote_media, self.clock.time_msec() |
124 | 127 | ) |
125 | 128 | |
126 | def mark_recently_accessed(self, server_name, media_id): | |
129 | def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: | |
127 | 130 | """Mark the given media as recently accessed. |
128 | 131 | |
129 | 132 | Args: |
130 | server_name (str|None): Origin server of media, or None if local | |
131 | media_id (str): The media ID of the content | |
133 | server_name: Origin server of media, or None if local | |
134 | media_id: The media ID of the content | |
132 | 135 | """ |
133 | 136 | if server_name: |
134 | 137 | self.recently_accessed_remotes.add((server_name, media_id)) |
458 | 461 | def _get_thumbnail_requirements(self, media_type): |
459 | 462 | return self.thumbnail_requirements.get(media_type, ()) |
460 | 463 | |
461 | def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): | |
464 | def _generate_thumbnail( | |
465 | self, | |
466 | thumbnailer: Thumbnailer, | |
467 | t_width: int, | |
468 | t_height: int, | |
469 | t_method: str, | |
470 | t_type: str, | |
471 | ) -> Optional[BytesIO]: | |
462 | 472 | m_width = thumbnailer.width |
463 | 473 | m_height = thumbnailer.height |
464 | 474 | |
469 | 479 | m_height, |
470 | 480 | self.max_image_pixels, |
471 | 481 | ) |
472 | return | |
482 | return None | |
473 | 483 | |
474 | 484 | if thumbnailer.transpose_method is not None: |
475 | 485 | m_width, m_height = thumbnailer.transpose() |
476 | 486 | |
477 | 487 | if t_method == "crop": |
478 | t_byte_source = thumbnailer.crop(t_width, t_height, t_type) | |
488 | return thumbnailer.crop(t_width, t_height, t_type) | |
479 | 489 | elif t_method == "scale": |
480 | 490 | t_width, t_height = thumbnailer.aspect(t_width, t_height) |
481 | 491 | t_width = min(m_width, t_width) |
482 | 492 | t_height = min(m_height, t_height) |
483 | t_byte_source = thumbnailer.scale(t_width, t_height, t_type) | |
484 | else: | |
485 | t_byte_source = None | |
486 | ||
487 | return t_byte_source | |
493 | return thumbnailer.scale(t_width, t_height, t_type) | |
494 | ||
495 | return None | |
488 | 496 | |
489 | 497 | async def generate_local_exact_thumbnail( |
490 | 498 | self, |
775 | 783 | |
776 | 784 | return {"width": m_width, "height": m_height} |
777 | 785 | |
778 | async def delete_old_remote_media(self, before_ts): | |
786 | async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: | |
779 | 787 | old_media = await self.store.get_remote_media_before(before_ts) |
780 | 788 | |
781 | 789 | deleted = 0 |
927 | 935 | within a given rectangle. |
928 | 936 | """ |
929 | 937 | |
930 | def __init__(self, hs): | |
938 | def __init__(self, hs: "HomeServer"): | |
931 | 939 | # If we're not configured to use it, raise if we somehow got here. |
932 | 940 | if not hs.config.can_load_media_repo: |
933 | 941 | raise ConfigError("Synapse is not configured to use a media repo.") |
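`_generate_thumbnail` now returns `Optional[BytesIO]` straight from the crop/scale branches, which makes the calling contract clearer. A rough sketch under invented names (`repo` and `thumbnailer` stand in for a `MediaRepository` and a prepared `Thumbnailer`):

```python
thumb = repo._generate_thumbnail(thumbnailer, 320, 240, "scale", "image/png")
if thumb is None:
    # The source exceeded max_image_pixels, or t_method was neither
    # "crop" nor "scale"; callers treat this as "no thumbnail produced".
    ...
```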
0 | 0 | # -*- coding: utf-8 -*- |
1 | # Copyright 2018 New Vecotr Ltd | |
1 | # Copyright 2018-2021 The Matrix.org Foundation C.I.C. | |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
17 | 17 | import shutil |
18 | 18 | from typing import IO, TYPE_CHECKING, Any, Optional, Sequence |
19 | 19 | |
20 | from twisted.internet.defer import Deferred | |
21 | from twisted.internet.interfaces import IConsumer | |
20 | 22 | from twisted.protocols.basic import FileSender |
21 | 23 | |
22 | 24 | from synapse.logging.context import defer_to_thread, make_deferred_yieldable |
269 | 271 | return self.filepaths.local_media_filepath_rel(file_info.file_id) |
270 | 272 | |
271 | 273 | |
272 | def _write_file_synchronously(source, dest): | |
274 | def _write_file_synchronously(source: IO, dest: IO) -> None: | |
273 | 275 | """Write `source` to the file like `dest` synchronously. Should be called |
274 | 276 | from a thread. |
275 | 277 | |
285 | 287 | """Wraps an open file that can be sent to a request. |
286 | 288 | |
287 | 289 | Args: |
288 | open_file (file): A file like object to be streamed to the client, | 
290 | open_file: A file like object to be streamed to the client, | 
289 | 291 | is closed when finished streaming. |
290 | 292 | """ |
291 | 293 | |
292 | def __init__(self, open_file): | |
294 | def __init__(self, open_file: IO): | |
293 | 295 | self.open_file = open_file |
294 | 296 | |
295 | def write_to_consumer(self, consumer): | |
297 | def write_to_consumer(self, consumer: IConsumer) -> Deferred: | |
296 | 298 | return make_deferred_yieldable( |
297 | 299 | FileSender().beginFileTransfer(self.open_file, consumer) |
298 | 300 | ) |
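`FileResponder.write_to_consumer` is now annotated as returning a `Deferred`; in practice a responder is handed off via `respond_with_responder` from `_base.py`. A minimal sketch, assuming a local path `fname` and a twisted `request` are already in scope:

```python
# Stream a local file back to the client through the Responder interface.
responder = FileResponder(open(fname, "rb"))
await respond_with_responder(request, responder, "image/png", file_size=None)
# file_size=None is fine: add_file_headers now skips Content-Length when unknown.
```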
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2016 OpenMarket Ltd |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
11 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 13 | # See the License for the specific language governing permissions and |
13 | 14 | # limitations under the License. |
14 | ||
15 | 15 | import datetime |
16 | 16 | import errno |
17 | 17 | import fnmatch |
22 | 22 | import shutil |
23 | 23 | import sys |
24 | 24 | import traceback |
25 | from typing import Dict, Optional | |
25 | from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union | |
26 | 26 | from urllib import parse as urlparse |
27 | 27 | |
28 | 28 | import attr |
29 | 29 | |
30 | 30 | from twisted.internet.error import DNSLookupError |
31 | from twisted.web.http import Request | |
31 | 32 | |
32 | 33 | from synapse.api.errors import Codes, SynapseError |
33 | 34 | from synapse.http.client import SimpleHttpClient |
40 | 41 | from synapse.logging.context import make_deferred_yieldable, run_in_background |
41 | 42 | from synapse.metrics.background_process_metrics import run_as_background_process |
42 | 43 | from synapse.rest.media.v1._base import get_filename_from_headers |
44 | from synapse.rest.media.v1.media_storage import MediaStorage | |
43 | 45 | from synapse.util import json_encoder |
44 | 46 | from synapse.util.async_helpers import ObservableDeferred |
45 | 47 | from synapse.util.caches.expiringcache import ExpiringCache |
46 | 48 | from synapse.util.stringutils import random_string |
47 | 49 | |
48 | 50 | from ._base import FileInfo |
51 | ||
52 | if TYPE_CHECKING: | |
53 | from lxml import etree | |
54 | ||
55 | from synapse.app.homeserver import HomeServer | |
56 | from synapse.rest.media.v1.media_repository import MediaRepository | |
49 | 57 | |
50 | 58 | logger = logging.getLogger(__name__) |
51 | 59 | |
118 | 126 | class PreviewUrlResource(DirectServeJsonResource): |
119 | 127 | isLeaf = True |
120 | 128 | |
121 | def __init__(self, hs, media_repo, media_storage): | |
129 | def __init__( | |
130 | self, | |
131 | hs: "HomeServer", | |
132 | media_repo: "MediaRepository", | |
133 | media_storage: MediaStorage, | |
134 | ): | |
122 | 135 | super().__init__() |
123 | 136 | |
124 | 137 | self.auth = hs.get_auth() |
165 | 178 | self._start_expire_url_cache_data, 10 * 1000 |
166 | 179 | ) |
167 | 180 | |
168 | async def _async_render_OPTIONS(self, request): | |
181 | async def _async_render_OPTIONS(self, request: Request) -> None: | |
169 | 182 | request.setHeader(b"Allow", b"OPTIONS, GET") |
170 | 183 | respond_with_json(request, 200, {}, send_cors=True) |
171 | 184 | |
172 | async def _async_render_GET(self, request): | |
185 | async def _async_render_GET(self, request: Request) -> None: | |
173 | 186 | |
174 | 187 | # XXX: if get_user_by_req fails, what should we do in an async render? |
175 | 188 | requester = await self.auth.get_user_by_req(request) |
449 | 462 | logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) |
450 | 463 | raise OEmbedError() from e |
451 | 464 | |
452 | async def _download_url(self, url: str, user): | |
465 | async def _download_url(self, url: str, user: str) -> Dict[str, Any]: | |
453 | 466 | # TODO: we should probably honour robots.txt... except in practice |
454 | 467 | # we're most likely being explicitly triggered by a human rather than a |
455 | 468 | # bot, so are we really a robot? |
579 | 592 | "expire_url_cache_data", self._expire_url_cache_data |
580 | 593 | ) |
581 | 594 | |
582 | async def _expire_url_cache_data(self): | |
595 | async def _expire_url_cache_data(self) -> None: | |
583 | 596 | """Clean up expired url cache content, media and thumbnails. |
584 | 597 | """ |
585 | 598 | # TODO: Delete from backup media store |
675 | 688 | logger.debug("No media removed from url cache") |
676 | 689 | |
677 | 690 | |
678 | def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: | |
691 | def decode_and_calc_og( | |
692 | body: bytes, media_uri: str, request_encoding: Optional[str] = None | |
693 | ) -> Dict[str, Optional[str]]: | |
679 | 694 | # If there's no body, nothing useful is going to be found. |
680 | 695 | if not body: |
681 | 696 | return {} |
696 | 711 | return og |
697 | 712 | |
698 | 713 | |
699 | def _calc_og(tree, media_uri): | |
714 | def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: | |
700 | 715 | # suck our tree into lxml and define our OG response. |
701 | 716 | |
702 | 717 | # if we see any image URLs in the OG response, then spider them |
800 | 815 | for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) |
801 | 816 | ) |
802 | 817 | og["og:description"] = summarize_paragraphs(text_nodes) |
803 | else: | |
818 | elif og["og:description"]: | |
819 | # This must be a non-empty string at this point. | |
820 | assert isinstance(og["og:description"], str) | |
804 | 821 | og["og:description"] = summarize_paragraphs([og["og:description"]]) |
805 | 822 | |
806 | 823 | # TODO: delete the url downloads to stop diskfilling, |
808 | 825 | return og |
809 | 826 | |
810 | 827 | |
811 | def _iterate_over_text(tree, *tags_to_ignore): | |
828 | def _iterate_over_text( | |
829 | tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] | |
830 | ) -> Generator[str, None, None]: | |
812 | 831 | """Iterate over the tree returning text nodes in a depth first fashion, |
813 | 832 | skipping text nodes inside certain tags. |
814 | 833 | """ |
842 | 861 | ) |
843 | 862 | |
844 | 863 | |
845 | def _rebase_url(url, base): | |
846 | base = list(urlparse.urlparse(base)) | |
847 | url = list(urlparse.urlparse(url)) | |
848 | if not url[0]: # fix up schema | |
849 | url[0] = base[0] or "http" | |
850 | if not url[1]: # fix up hostname | |
851 | url[1] = base[1] | |
852 | if not url[2].startswith("/"): | |
853 | url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] | |
854 | return urlparse.urlunparse(url) | |
855 | ||
856 | ||
857 | def _is_media(content_type): | |
858 | if content_type.lower().startswith("image/"): | |
859 | return True | |
860 | ||
861 | ||
862 | def _is_html(content_type): | |
864 | def _rebase_url(url: str, base: str) -> str: | |
865 | base_parts = list(urlparse.urlparse(base)) | |
866 | url_parts = list(urlparse.urlparse(url)) | |
867 | if not url_parts[0]: # fix up schema | |
868 | url_parts[0] = base_parts[0] or "http" | |
869 | if not url_parts[1]: # fix up hostname | |
870 | url_parts[1] = base_parts[1] | |
871 | if not url_parts[2].startswith("/"): | |
872 | url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] | |
873 | return urlparse.urlunparse(url_parts) | |
874 | ||
875 | ||
876 | def _is_media(content_type: str) -> bool: | |
877 | return content_type.lower().startswith("image/") | |
878 | ||
879 | ||
880 | def _is_html(content_type: str) -> bool: | |
863 | 881 | content_type = content_type.lower() |
864 | if content_type.startswith("text/html") or content_type.startswith( | |
882 | return content_type.startswith("text/html") or content_type.startswith( | |
865 | 883 | "application/xhtml" |
866 | ): | |
867 | return True | |
868 | ||
869 | ||
870 | def summarize_paragraphs(text_nodes, min_size=200, max_size=500): | |
884 | ) | |
885 | ||
886 | ||
887 | def summarize_paragraphs( | |
888 | text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 | |
889 | ) -> Optional[str]: | |
871 | 890 | # Try to get a summary of between 200 and 500 words, respecting |
872 | 891 | # first paragraph and then word boundaries. |
873 | 892 | # TODO: Respect sentences? |
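The `_rebase_url` rewrite only renames the locals for clarity; behaviour is unchanged. For reference, a hedged sketch of how it resolves relative URLs (all URLs invented):

```python
_rebase_url("/img/cat.png", "https://example.com/blog/post")
# -> "https://example.com/img/cat.png"   (absolute path: only scheme/host inherited)

_rebase_url("cat.png", "https://example.com/blog/post")
# -> "https://example.com/blog/cat.png"  (relative path rebased onto the base's directory)
```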
0 | 0 | # -*- coding: utf-8 -*- |
1 | # Copyright 2018 New Vector Ltd | |
1 | # Copyright 2018-2021 The Matrix.org Foundation C.I.C. | |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import abc | |
15 | 16 | import logging |
16 | 17 | import os |
17 | 18 | import shutil |
18 | from typing import Optional | |
19 | from typing import TYPE_CHECKING, Optional | |
19 | 20 | |
20 | 21 | from synapse.config._base import Config |
21 | 22 | from synapse.logging.context import defer_to_thread, run_in_background |
26 | 27 | |
27 | 28 | logger = logging.getLogger(__name__) |
28 | 29 | |
30 | if TYPE_CHECKING: | |
31 | from synapse.app.homeserver import HomeServer | |
29 | 32 | |
30 | class StorageProvider: | |
33 | ||
34 | class StorageProvider(metaclass=abc.ABCMeta): | |
31 | 35 | """A storage provider is a service that can store uploaded media and |
32 | 36 | retrieve them. |
33 | 37 | """ |
34 | 38 | |
35 | async def store_file(self, path: str, file_info: FileInfo): | |
39 | @abc.abstractmethod | |
40 | async def store_file(self, path: str, file_info: FileInfo) -> None: | |
36 | 41 | """Store the file described by file_info. The actual contents can be |
37 | 42 | retrieved by reading the file in file_info.upload_path. |
38 | 43 | |
41 | 46 | file_info: The metadata of the file. |
42 | 47 | """ |
43 | 48 | |
49 | @abc.abstractmethod | |
44 | 50 | async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: |
45 | 51 | """Attempt to fetch the file described by file_info and stream it |
46 | 52 | into writer. |
77 | 83 | self.store_synchronous = store_synchronous |
78 | 84 | self.store_remote = store_remote |
79 | 85 | |
80 | def __str__(self): | |
86 | def __str__(self) -> str: | |
81 | 87 | return "StorageProviderWrapper[%s]" % (self.backend,) |
82 | 88 | |
83 | async def store_file(self, path, file_info): | |
89 | async def store_file(self, path: str, file_info: FileInfo) -> None: | |
84 | 90 | if not file_info.server_name and not self.store_local: |
85 | 91 | return None |
86 | 92 | |
90 | 96 | if self.store_synchronous: |
91 | 97 | # store_file is supposed to return an Awaitable, but guard |
92 | 98 | # against improper implementations. |
93 | return await maybe_awaitable(self.backend.store_file(path, file_info)) | |
99 | await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore | |
94 | 100 | else: |
95 | 101 | # TODO: Handle errors. |
96 | 102 | async def store(): |
102 | 108 | logger.exception("Error storing file") |
103 | 109 | |
104 | 110 | run_in_background(store) |
105 | return None | |
106 | 111 | |
107 | async def fetch(self, path, file_info): | |
112 | async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: | |
108 | 113 | # fetch is supposed to return an Awaitable, but guard
109 | 114 | # against improper implementations. |
110 | 115 | return await maybe_awaitable(self.backend.fetch(path, file_info)) |
114 | 119 | """A storage provider that stores files in a directory on a filesystem. |
115 | 120 | |
116 | 121 | Args: |
117 | hs (HomeServer) | |
122 | hs | |
118 | 123 | config: The config returned by `parse_config`. |
119 | 124 | """ |
120 | 125 | |
121 | def __init__(self, hs, config): | |
126 | def __init__(self, hs: "HomeServer", config: str): | |
122 | 127 | self.hs = hs |
123 | 128 | self.cache_directory = hs.config.media_store_path |
124 | 129 | self.base_directory = config |
126 | 131 | def __str__(self): |
127 | 132 | return "FileStorageProviderBackend[%s]" % (self.base_directory,) |
128 | 133 | |
129 | async def store_file(self, path, file_info): | |
134 | async def store_file(self, path: str, file_info: FileInfo) -> None: | |
130 | 135 | """See StorageProvider.store_file""" |
131 | 136 | |
132 | 137 | primary_fname = os.path.join(self.cache_directory, path) |
136 | 141 | if not os.path.exists(dirname): |
137 | 142 | os.makedirs(dirname) |
138 | 143 | |
139 | return await defer_to_thread( | |
144 | await defer_to_thread( | |
140 | 145 | self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname |
141 | 146 | ) |
142 | 147 | |
143 | async def fetch(self, path, file_info): | |
148 | async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: | |
144 | 149 | """See StorageProvider.fetch""" |
145 | 150 | |
146 | 151 | backup_fname = os.path.join(self.base_directory, path) |
147 | 152 | if os.path.isfile(backup_fname): |
148 | 153 | return FileResponder(open(backup_fname, "rb")) |
149 | 154 | |
155 | return None | |
156 | ||
150 | 157 | @staticmethod |
151 | def parse_config(config): | |
158 | def parse_config(config: dict) -> str: | |
152 | 159 | """Called on startup to parse config supplied. This should parse |
153 | 160 | the config and raise if there is a problem. |
154 | 161 |
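Since `StorageProvider` is now a real ABC, third-party backends must implement both abstract methods. A minimal do-nothing sketch of the required shape (the class is invented; `StorageProvider`, `FileInfo` and `Responder` are the module's own names):

```python
from typing import Optional

class NullStorageProvider(StorageProvider):
    """Accepts every write and serves no reads; illustration only."""

    async def store_file(self, path: str, file_info: FileInfo) -> None:
        return None  # a real backend would copy the file somewhere durable

    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
        return None  # nothing stored here, so let other providers answer
```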
0 | 0 | # -*- coding: utf-8 -*- |
1 | # Copyright 2014 - 2016 OpenMarket Ltd | |
1 | # Copyright 2014-2016 OpenMarket Ltd | |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
14 | 15 | |
15 | 16 | |
16 | 17 | import logging |
18 | from typing import TYPE_CHECKING | |
19 | ||
20 | from twisted.web.http import Request | |
17 | 21 | |
18 | 22 | from synapse.api.errors import SynapseError |
19 | 23 | from synapse.http.server import DirectServeJsonResource, set_cors_headers |
20 | 24 | from synapse.http.servlet import parse_integer, parse_string |
25 | from synapse.rest.media.v1.media_storage import MediaStorage | |
21 | 26 | |
22 | 27 | from ._base import ( |
23 | 28 | FileInfo, |
27 | 32 | respond_with_responder, |
28 | 33 | ) |
29 | 34 | |
35 | if TYPE_CHECKING: | |
36 | from synapse.app.homeserver import HomeServer | |
37 | from synapse.rest.media.v1.media_repository import MediaRepository | |
38 | ||
30 | 39 | logger = logging.getLogger(__name__) |
31 | 40 | |
32 | 41 | |
33 | 42 | class ThumbnailResource(DirectServeJsonResource): |
34 | 43 | isLeaf = True |
35 | 44 | |
36 | def __init__(self, hs, media_repo, media_storage): | |
45 | def __init__( | |
46 | self, | |
47 | hs: "HomeServer", | |
48 | media_repo: "MediaRepository", | |
49 | media_storage: MediaStorage, | |
50 | ): | |
37 | 51 | super().__init__() |
38 | 52 | |
39 | 53 | self.store = hs.get_datastore() |
42 | 56 | self.dynamic_thumbnails = hs.config.dynamic_thumbnails |
43 | 57 | self.server_name = hs.hostname |
44 | 58 | |
45 | async def _async_render_GET(self, request): | |
59 | async def _async_render_GET(self, request: Request) -> None: | |
46 | 60 | set_cors_headers(request) |
47 | 61 | server_name, media_id, _ = parse_media_id(request) |
48 | 62 | width = parse_integer(request, "width", required=True) |
72 | 86 | self.media_repo.mark_recently_accessed(server_name, media_id) |
73 | 87 | |
74 | 88 | async def _respond_local_thumbnail( |
75 | self, request, media_id, width, height, method, m_type | |
76 | ): | |
89 | self, | |
90 | request: Request, | |
91 | media_id: str, | |
92 | width: int, | |
93 | height: int, | |
94 | method: str, | |
95 | m_type: str, | |
96 | ) -> None: | |
77 | 97 | media_info = await self.store.get_local_media(media_id) |
78 | 98 | |
79 | 99 | if not media_info: |
113 | 133 | |
114 | 134 | async def _select_or_generate_local_thumbnail( |
115 | 135 | self, |
116 | request, | |
117 | media_id, | |
118 | desired_width, | |
119 | desired_height, | |
120 | desired_method, | |
121 | desired_type, | |
122 | ): | |
136 | request: Request, | |
137 | media_id: str, | |
138 | desired_width: int, | |
139 | desired_height: int, | |
140 | desired_method: str, | |
141 | desired_type: str, | |
142 | ) -> None: | |
123 | 143 | media_info = await self.store.get_local_media(media_id) |
124 | 144 | |
125 | 145 | if not media_info: |
177 | 197 | |
178 | 198 | async def _select_or_generate_remote_thumbnail( |
179 | 199 | self, |
180 | request, | |
181 | server_name, | |
182 | media_id, | |
183 | desired_width, | |
184 | desired_height, | |
185 | desired_method, | |
186 | desired_type, | |
187 | ): | |
200 | request: Request, | |
201 | server_name: str, | |
202 | media_id: str, | |
203 | desired_width: int, | |
204 | desired_height: int, | |
205 | desired_method: str, | |
206 | desired_type: str, | |
207 | ) -> None: | |
188 | 208 | media_info = await self.media_repo.get_remote_media_info(server_name, media_id) |
189 | 209 | |
190 | 210 | thumbnail_infos = await self.store.get_remote_media_thumbnails( |
238 | 258 | raise SynapseError(400, "Failed to generate thumbnail.") |
239 | 259 | |
240 | 260 | async def _respond_remote_thumbnail( |
241 | self, request, server_name, media_id, width, height, method, m_type | |
242 | ): | |
261 | self, | |
262 | request: Request, | |
263 | server_name: str, | |
264 | media_id: str, | |
265 | width: int, | |
266 | height: int, | |
267 | method: str, | |
268 | m_type: str, | |
269 | ) -> None: | |
243 | 270 | # TODO: Don't download the whole remote file |
244 | 271 | # We should proxy the thumbnail from the remote server instead of |
245 | 272 | # downloading the remote file and generating our own thumbnails. |
274 | 301 | |
275 | 302 | def _select_thumbnail( |
276 | 303 | self, |
277 | desired_width, | |
278 | desired_height, | |
279 | desired_method, | |
280 | desired_type, | |
304 | desired_width: int, | |
305 | desired_height: int, | |
306 | desired_method: str, | |
307 | desired_type: str, | |
281 | 308 | thumbnail_infos, |
282 | ): | |
309 | ) -> dict: | |
283 | 310 | d_w = desired_width |
284 | 311 | d_h = desired_height |
285 | 312 |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
13 | 14 | # limitations under the License. |
14 | 15 | import logging |
15 | 16 | from io import BytesIO |
17 | from typing import Tuple | |
16 | 18 | |
17 | 19 | from PIL import Image |
18 | 20 | |
38 | 40 | |
39 | 41 | FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} |
40 | 42 | |
41 | def __init__(self, input_path): | |
43 | def __init__(self, input_path: str): | |
42 | 44 | try: |
43 | 45 | self.image = Image.open(input_path) |
44 | 46 | except OSError as e: |
58 | 60 | # A lot of parsing errors can happen when parsing EXIF |
59 | 61 | logger.info("Error parsing image EXIF information: %s", e) |
60 | 62 | |
61 | def transpose(self): | |
63 | def transpose(self) -> Tuple[int, int]: | |
62 | 64 | """Transpose the image using its EXIF Orientation tag |
63 | 65 | |
64 | 66 | Returns: |
65 | Tuple[int, int]: (width, height) containing the new image size in pixels. | |
67 | A tuple containing the new image size in pixels as (width, height). | |
66 | 68 | """ |
67 | 69 | if self.transpose_method is not None: |
68 | 70 | self.image = self.image.transpose(self.transpose_method) |
72 | 74 | self.image.info["exif"] = None |
73 | 75 | return self.image.size |
74 | 76 | |
75 | def aspect(self, max_width, max_height): | |
77 | def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: | |
76 | 78 | """Calculate the largest size that preserves aspect ratio which |
77 | 79 | fits within the given rectangle:: |
78 | 80 | |
90 | 92 | else: |
91 | 93 | return (max_height * self.width) // self.height, max_height |
92 | 94 | |
93 | def _resize(self, width, height): | |
95 | def _resize(self, width: int, height: int) -> Image: | |
94 | 96 | # 1-bit or 8-bit color palette images need converting to RGB |
95 | 97 | # otherwise they will be scaled using nearest neighbour which |
96 | 98 | # looks awful |
98 | 100 | self.image = self.image.convert("RGB") |
99 | 101 | return self.image.resize((width, height), Image.ANTIALIAS) |
100 | 102 | |
101 | def scale(self, width, height, output_type): | |
103 | def scale(self, width: int, height: int, output_type: str) -> BytesIO: | |
102 | 104 | """Rescales the image to the given dimensions. |
103 | 105 | |
104 | 106 | Returns: |
107 | 109 | scaled = self._resize(width, height) |
108 | 110 | return self._encode_image(scaled, output_type) |
109 | 111 | |
110 | def crop(self, width, height, output_type): | |
112 | def crop(self, width: int, height: int, output_type: str) -> BytesIO: | |
111 | 113 | """Rescales and crops the image to the given dimensions preserving |
112 | 114 | aspect:: |
113 | 115 | (w_in / h_in) = (w_scaled / h_scaled) |
135 | 137 | cropped = scaled_image.crop((crop_left, 0, crop_right, height)) |
136 | 138 | return self._encode_image(cropped, output_type) |
137 | 139 | |
138 | def _encode_image(self, output_image, output_type): | |
140 | def _encode_image(self, output_image: Image, output_type: str) -> BytesIO: | |
139 | 141 | output_bytes_io = BytesIO() |
140 | 142 | fmt = self.FORMATS[output_type] |
141 | 143 | if fmt == "JPEG": |
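The new annotation on `aspect` spells out the contract: given a bounding box, return the largest size that keeps the source aspect ratio. A worked example with invented numbers, exercising the height-limited branch shown above:

```python
t = Thumbnailer("photo.jpg")  # suppose the source image is 500x1000 (portrait)
t.aspect(300, 300)
# -> (150, 300): height is the limiting dimension, so
#    width = (300 * 500) // 1000 = 150
```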
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
13 | 14 | # limitations under the License. |
14 | 15 | |
15 | 16 | import logging |
17 | from typing import TYPE_CHECKING | |
18 | ||
19 | from twisted.web.http import Request | |
16 | 20 | |
17 | 21 | from synapse.api.errors import Codes, SynapseError |
18 | 22 | from synapse.http.server import DirectServeJsonResource, respond_with_json |
19 | 23 | from synapse.http.servlet import parse_string |
24 | ||
25 | if TYPE_CHECKING: | |
26 | from synapse.app.homeserver import HomeServer | |
27 | from synapse.rest.media.v1.media_repository import MediaRepository | |
20 | 28 | |
21 | 29 | logger = logging.getLogger(__name__) |
22 | 30 | |
24 | 32 | class UploadResource(DirectServeJsonResource): |
25 | 33 | isLeaf = True |
26 | 34 | |
27 | def __init__(self, hs, media_repo): | |
35 | def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): | |
28 | 36 | super().__init__() |
29 | 37 | |
30 | 38 | self.media_repo = media_repo |
36 | 44 | self.max_upload_size = hs.config.max_upload_size |
37 | 45 | self.clock = hs.get_clock() |
38 | 46 | |
39 | async def _async_render_OPTIONS(self, request): | |
47 | async def _async_render_OPTIONS(self, request: Request) -> None: | |
40 | 48 | respond_with_json(request, 200, {}, send_cors=True) |
41 | 49 | |
42 | async def _async_render_POST(self, request): | |
50 | async def _async_render_POST(self, request: Request) -> None: | |
43 | 51 | requester = await self.auth.get_user_by_req(request) |
44 | 52 | # TODO: The checks here are a bit late. The content will have |
45 | 53 | # already been uploaded to a tmp file at this point |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | import logging | |
15 | from typing import TYPE_CHECKING | |
16 | ||
17 | from synapse.http.server import ( | |
18 | DirectServeHtmlResource, | |
19 | finish_request, | |
20 | respond_with_html, | |
21 | ) | |
22 | from synapse.http.servlet import parse_string | |
23 | from synapse.http.site import SynapseRequest | |
24 | ||
25 | if TYPE_CHECKING: | |
26 | from synapse.server import HomeServer | |
27 | ||
28 | logger = logging.getLogger(__name__) | |
29 | ||
30 | ||
31 | class PickIdpResource(DirectServeHtmlResource): | |
32 | """IdP picker resource. | |
33 | ||
34 | This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page | |
35 | which prompts the user to choose an Identity Provider from the list. | |
36 | """ | |
37 | ||
38 | def __init__(self, hs: "HomeServer"): | |
39 | super().__init__() | |
40 | self._sso_handler = hs.get_sso_handler() | |
41 | self._sso_login_idp_picker_template = ( | |
42 | hs.config.sso.sso_login_idp_picker_template | |
43 | ) | |
44 | self._server_name = hs.hostname | |
45 | ||
46 | async def _async_render_GET(self, request: SynapseRequest) -> None: | |
47 | client_redirect_url = parse_string( | |
48 | request, "redirectUrl", required=True, encoding="utf-8" | |
49 | ) | |
50 | idp = parse_string(request, "idp", required=False) | |
51 | ||
52 | # if we need to pick an IdP, do so | |
53 | if not idp: | |
54 | return await self._serve_id_picker(request, client_redirect_url) | |
55 | ||
56 | # otherwise, redirect to the IdP's redirect URI | |
57 | providers = self._sso_handler.get_identity_providers() | |
58 | auth_provider = providers.get(idp) | |
59 | if not auth_provider: | |
60 | logger.info("Unknown idp %r", idp) | |
61 | self._sso_handler.render_error( | |
62 | request, "unknown_idp", "Unknown identity provider ID" | |
63 | ) | |
64 | return | |
65 | ||
66 | sso_url = await auth_provider.handle_redirect_request( | |
67 | request, client_redirect_url.encode("utf8") | |
68 | ) | |
69 | logger.info("Redirecting to %s", sso_url) | |
70 | request.redirect(sso_url) | |
71 | finish_request(request) | |
72 | ||
73 | async def _serve_id_picker( | |
74 | self, request: SynapseRequest, client_redirect_url: str | |
75 | ) -> None: | |
76 | # otherwise, serve up the IdP picker | |
77 | providers = self._sso_handler.get_identity_providers() | |
78 | html = self._sso_login_idp_picker_template.render( | |
79 | redirect_url=client_redirect_url, | |
80 | server_name=self._server_name, | |
81 | providers=providers.values(), | |
82 | ) | |
83 | respond_with_html(request, 200, html) |
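The new picker is mounted at `/_synapse/client/pick_idp` and handles two query shapes. A hedged sketch (hostnames and the `idp` value are invented; note the `oidc-` prefix this release adds to `idp_id`s):

```python
from urllib.parse import urlencode

base = "https://synapse.example.com/_synapse/client/pick_idp"

# Without `idp`, the HTML picker page is rendered from the template.
picker_url = base + "?" + urlencode({"redirectUrl": "https://client.example/"})

# With `idp`, the user is redirected straight into that provider's SSO flow.
direct_url = base + "?" + urlencode(
    {"redirectUrl": "https://client.example/", "idp": "oidc-github"}
)
```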
33 | 33 | self._config = hs.config |
34 | 34 | |
35 | 35 | def get_well_known(self): |
36 | # if we don't have a public_baseurl, we can't help much here. | |
37 | if self._config.public_baseurl is None: | |
38 | return None | |
39 | ||
40 | 36 | result = {"m.homeserver": {"base_url": self._config.public_baseurl}} |
41 | 37 | |
42 | 38 | if self._config.default_identity_server: |
54 | 54 | from synapse.federation.transport.client import TransportLayerClient |
55 | 55 | from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer |
56 | 56 | from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler |
57 | from synapse.handlers.account_data import AccountDataHandler | |
57 | 58 | from synapse.handlers.account_validity import AccountValidityHandler |
58 | 59 | from synapse.handlers.acme import AcmeHandler |
59 | 60 | from synapse.handlers.admin import AdminHandler |
282 | 283 | """ |
283 | 284 | return self._reactor |
284 | 285 | |
285 | def get_ip_from_request(self, request) -> str: | |
286 | # X-Forwarded-For is handled by our custom request type. | |
287 | return request.getClientIP() | |
288 | ||
289 | 286 | def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: |
290 | 287 | return domain_specific_string.domain == self.hostname |
291 | 288 | |
504 | 501 | return InitialSyncHandler(self) |
505 | 502 | |
506 | 503 | @cache_in_self |
507 | def get_profile_handler(self): | |
504 | def get_profile_handler(self) -> ProfileHandler: | |
508 | 505 | return ProfileHandler(self) |
509 | 506 | |
510 | 507 | @cache_in_self |
714 | 711 | def get_module_api(self) -> ModuleApi: |
715 | 712 | return ModuleApi(self, self.get_auth_handler()) |
716 | 713 | |
714 | @cache_in_self | |
715 | def get_account_data_handler(self) -> AccountDataHandler: | |
716 | return AccountDataHandler(self) | |
717 | ||
717 | 718 | async def remove_pusher(self, app_id: str, push_key: str, user_id: str): |
718 | 719 | return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) |
719 | 720 |
41 | 41 | self._auth = hs.get_auth() |
42 | 42 | self._config = hs.config |
43 | 43 | self._resouce_limited = False |
44 | self._account_data_handler = hs.get_account_data_handler() | |
44 | 45 | self._message_handler = hs.get_message_handler() |
45 | 46 | self._state = hs.get_state_handler() |
46 | 47 | |
176 | 177 | # tag already present, nothing to do here |
177 | 178 | need_to_set_tag = False |
178 | 179 | if need_to_set_tag: |
179 | max_id = await self._store.add_tag_to_room( | |
180 | max_id = await self._account_data_handler.add_tag_to_room( | |
180 | 181 | user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} |
181 | 182 | ) |
182 | 183 | self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) |
34 | 34 | |
35 | 35 | self._store = hs.get_datastore() |
36 | 36 | self._config = hs.config |
37 | self._account_data_handler = hs.get_account_data_handler() | |
37 | 38 | self._room_creation_handler = hs.get_room_creation_handler() |
38 | 39 | self._room_member_handler = hs.get_room_member_handler() |
39 | 40 | self._event_creation_handler = hs.get_event_creation_handler() |
162 | 163 | ) |
163 | 164 | room_id = info["room_id"] |
164 | 165 | |
165 | max_id = await self._store.add_tag_to_room( | |
166 | max_id = await self._account_data_handler.add_tag_to_room( | |
166 | 167 | user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} |
167 | 168 | ) |
168 | 169 | self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) |
28 | 28 | form { |
29 | 29 | text-align: center; |
30 | 30 | margin: 10px 0 0 0; |
31 | } | |
32 | ||
33 | ul.radiobuttons { | |
34 | text-align: left; | |
35 | list-style: none; | |
31 | 36 | } |
32 | 37 | |
33 | 38 | /* |
41 | 41 | from synapse.config.database import DatabaseConnectionConfig |
42 | 42 | from synapse.logging.context import ( |
43 | 43 | LoggingContext, |
44 | LoggingContextOrSentinel, | |
45 | 44 | current_context, |
46 | 45 | make_deferred_yieldable, |
47 | 46 | ) |
49 | 48 | from synapse.storage.background_updates import BackgroundUpdater |
50 | 49 | from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine |
51 | 50 | from synapse.storage.types import Connection, Cursor |
51 | from synapse.storage.util.sequence import build_sequence_generator | |
52 | 52 | from synapse.types import Collection |
53 | 53 | |
54 | 54 | # python 3 does not have a maximum int value |
179 | 179 | _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] |
180 | 180 | |
181 | 181 | |
182 | R = TypeVar("R") | |
183 | ||
184 | ||
182 | 185 | class LoggingTransaction: |
183 | 186 | """An object that almost-transparently proxies for the 'txn' object |
184 | 187 | passed to the constructor. Adds logging and metrics to the .execute() |
266 | 269 | for val in args: |
267 | 270 | self.execute(sql, val) |
268 | 271 | |
272 | def execute_values(self, sql: str, *args: Any) -> List[Tuple]: | |
273 | """Corresponds to psycopg2.extras.execute_values. Only available when | |
274 | using postgres. | |
275 | ||
276 | Always sets fetch=True when calling `execute_values`, so will return the | 
277 | results. | |
278 | """ | |
279 | assert isinstance(self.database_engine, PostgresEngine) | |
280 | from psycopg2.extras import execute_values # type: ignore | |
281 | ||
282 | return self._do_execute( | |
283 | lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args | |
284 | ) | |
285 | ||
269 | 286 | def execute(self, sql: str, *args: Any) -> None: |
270 | 287 | self._do_execute(self.txn.execute, sql, *args) |
271 | 288 | |
276 | 293 | "Strip newlines out of SQL so that the loggers in the DB are on one line" |
277 | 294 | return " ".join(line.strip() for line in sql.splitlines() if line.strip()) |
278 | 295 | |
279 | def _do_execute(self, func, sql: str, *args: Any) -> None: | |
296 | def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R: | |
280 | 297 | sql = self._make_sql_one_line(sql) |
281 | 298 | |
282 | 299 | # TODO(paul): Maybe use 'info' and 'debug' for values? |
347 | 364 | return top_n_counters |
348 | 365 | |
349 | 366 | |
350 | R = TypeVar("R") | |
351 | ||
352 | ||
353 | 367 | class DatabasePool: |
354 | 368 | """Wraps a single physical database and connection pool. |
355 | 369 | |
397 | 411 | "upsert_safety_check", |
398 | 412 | self._check_safe_to_upsert, |
399 | 413 | ) |
414 | ||
415 | # We define this sequence here so that it can be referenced from both | |
416 | # the DataStore and PersistEventStore. | |
417 | def get_chain_id_txn(txn): | |
418 | txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") | |
419 | return txn.fetchone()[0] | |
420 | ||
421 | self.event_chain_id_gen = build_sequence_generator( | |
422 | engine, get_chain_id_txn, "event_auth_chain_id" | |
423 | ) | |
400 | 424 | |
401 | 425 | def is_running(self) -> bool: |
402 | 426 | """Is the database pool currently running |
670 | 694 | Returns: |
671 | 695 | The result of func |
672 | 696 | """ |
673 | parent_context = current_context() # type: Optional[LoggingContextOrSentinel] | |
674 | if not parent_context: | |
697 | curr_context = current_context() | |
698 | if not curr_context: | |
675 | 699 | logger.warning( |
676 | 700 | "Starting db connection from sentinel context: metrics will be lost" |
677 | 701 | ) |
678 | 702 | parent_context = None |
703 | else: | |
704 | assert isinstance(curr_context, LoggingContext) | |
705 | parent_context = curr_context | |
679 | 706 | |
680 | 707 | start_time = monotonic_time() |
681 | 708 |
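The new `LoggingTransaction.execute_values` wraps `psycopg2.extras.execute_values` with `fetch=True`, so statements passed to it should produce rows (e.g. via `RETURNING`). A hedged usage sketch, assuming `txn` is a `LoggingTransaction` on postgres; the table and values are invented:

```python
stream_ids = txn.execute_values(
    "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
    " VALUES %s RETURNING stream_id",  # psycopg2 expands the %s placeholder
    [
        ("@alice:example.com", "!abc:example.com", 41),
        ("@bob:example.com", "!abc:example.com", 42),
    ],
)
```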
126 | 126 | self._presence_id_gen = StreamIdGenerator( |
127 | 127 | db_conn, "presence_stream", "stream_id" |
128 | 128 | ) |
129 | self._device_inbox_id_gen = StreamIdGenerator( | |
130 | db_conn, "device_inbox", "stream_id" | |
131 | ) | |
132 | 129 | self._public_room_id_gen = StreamIdGenerator( |
133 | 130 | db_conn, "public_room_list_stream", "stream_id" |
134 | 131 | ) |
162 | 159 | database, |
163 | 160 | stream_name="caches", |
164 | 161 | instance_name=hs.get_instance_name(), |
165 | table="cache_invalidation_stream_by_instance", | |
166 | instance_column="instance_name", | |
167 | id_column="stream_id", | |
162 | tables=[ | |
163 | ( | |
164 | "cache_invalidation_stream_by_instance", | |
165 | "instance_name", | |
166 | "stream_id", | |
167 | ) | |
168 | ], | |
168 | 169 | sequence_name="cache_invalidation_stream_seq", |
169 | 170 | writers=[], |
170 | 171 | ) |
186 | 187 | "PresenceStreamChangeCache", |
187 | 188 | min_presence_val, |
188 | 189 | prefilled_cache=presence_cache_prefill, |
189 | ) | |
190 | ||
191 | max_device_inbox_id = self._device_inbox_id_gen.get_current_token() | |
192 | device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( | |
193 | db_conn, | |
194 | "device_inbox", | |
195 | entity_column="user_id", | |
196 | stream_column="stream_id", | |
197 | max_value=max_device_inbox_id, | |
198 | limit=1000, | |
199 | ) | |
200 | self._device_inbox_stream_cache = StreamChangeCache( | |
201 | "DeviceInboxStreamChangeCache", | |
202 | min_device_inbox_id, | |
203 | prefilled_cache=device_inbox_prefill, | |
204 | ) | |
205 | # The federation outbox and the local device inbox use the same |
206 | # stream_id generator. | |
207 | device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( | |
208 | db_conn, | |
209 | "device_federation_outbox", | |
210 | entity_column="destination", | |
211 | stream_column="stream_id", | |
212 | max_value=max_device_inbox_id, | |
213 | limit=1000, | |
214 | ) | |
215 | self._device_federation_outbox_stream_cache = StreamChangeCache( | |
216 | "DeviceFederationOutboxStreamChangeCache", | |
217 | min_device_outbox_id, | |
218 | prefilled_cache=device_outbox_prefill, | |
219 | 190 | ) |
220 | 191 | |
221 | 192 | device_list_max = self._device_list_id_gen.get_current_token() |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | import abc | |
17 | 16 | import logging |
18 | from typing import Dict, List, Optional, Tuple | |
17 | from typing import Dict, List, Optional, Set, Tuple | |
19 | 18 | |
20 | 19 | from synapse.api.constants import AccountDataTypes |
20 | from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |
21 | from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream | |
21 | 22 | from synapse.storage._base import SQLBaseStore, db_to_json |
22 | 23 | from synapse.storage.database import DatabasePool |
23 | from synapse.storage.util.id_generators import StreamIdGenerator | |
24 | from synapse.storage.engines import PostgresEngine | |
25 | from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator | |
24 | 26 | from synapse.types import JsonDict |
25 | 27 | from synapse.util import json_encoder |
26 | from synapse.util.caches.descriptors import _CacheContext, cached | |
28 | from synapse.util.caches.descriptors import cached | |
27 | 29 | from synapse.util.caches.stream_change_cache import StreamChangeCache |
28 | 30 | |
29 | 31 | logger = logging.getLogger(__name__) |
30 | 32 | |
31 | 33 | |
32 | # The ABCMeta metaclass ensures that it cannot be instantiated without | |
33 | # the abstract methods being implemented. | |
34 | class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): | |
34 | class AccountDataWorkerStore(SQLBaseStore): | |
35 | 35 | """This is an abstract base class where subclasses must implement |
36 | 36 | `get_max_account_data_stream_id` which can be called in the initializer. |
37 | 37 | """ |
38 | 38 | |
39 | 39 | def __init__(self, database: DatabasePool, db_conn, hs): |
40 | self._instance_name = hs.get_instance_name() | |
41 | ||
42 | if isinstance(database.engine, PostgresEngine): | |
43 | self._can_write_to_account_data = ( | |
44 | self._instance_name in hs.config.worker.writers.account_data | |
45 | ) | |
46 | ||
47 | self._account_data_id_gen = MultiWriterIdGenerator( | |
48 | db_conn=db_conn, | |
49 | db=database, | |
50 | stream_name="account_data", | |
51 | instance_name=self._instance_name, | |
52 | tables=[ | |
53 | ("room_account_data", "instance_name", "stream_id"), | |
54 | ("room_tags_revisions", "instance_name", "stream_id"), | |
55 | ("account_data", "instance_name", "stream_id"), | |
56 | ], | |
57 | sequence_name="account_data_sequence", | |
58 | writers=hs.config.worker.writers.account_data, | |
59 | ) | |
60 | else: | |
61 | self._can_write_to_account_data = True | |
62 | ||
63 | # We shouldn't be running in worker mode with SQLite, but it's useful |
64 | # to support it for unit tests. | |
65 | # | |
66 | # If this process is the writer then we need to use |
67 | # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets | |
68 | # updated over replication. (Multiple writers are not supported for | |
69 | # SQLite). | |
70 | if hs.get_instance_name() in hs.config.worker.writers.account_data: | |
71 | self._account_data_id_gen = StreamIdGenerator( | |
72 | db_conn, | |
73 | "room_account_data", | |
74 | "stream_id", | |
75 | extra_tables=[("room_tags_revisions", "stream_id")], | |
76 | ) | |
77 | else: | |
78 | self._account_data_id_gen = SlavedIdTracker( | |
79 | db_conn, | |
80 | "room_account_data", | |
81 | "stream_id", | |
82 | extra_tables=[("room_tags_revisions", "stream_id")], | |
83 | ) | |
84 | ||
40 | 85 | account_max = self.get_max_account_data_stream_id() |
41 | 86 | self._account_data_stream_cache = StreamChangeCache( |
42 | 87 | "AccountDataAndTagsChangeCache", account_max |
44 | 89 | |
45 | 90 | super().__init__(database, db_conn, hs) |
46 | 91 | |
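
To illustrate the writer/reader split set up above (only the writer allocates stream IDs; other processes merely track the highest position announced over replication), here is a toy tracker with assumed names, not the real `SlavedIdTracker`:

class IdTrackerSketch:
    """Read-only view of a stream position, maintained via replication."""

    def __init__(self, current: int):
        self._current = current

    def advance(self, instance_name: str, new_id: int) -> None:
        # Called from process_replication_rows when the writer announces
        # that it has persisted rows up to new_id.
        self._current = max(self._current, new_id)

    def get_current_token(self) -> int:
        return self._current

tracker = IdTrackerSketch(10)
tracker.advance("master", 12)
assert tracker.get_current_token() == 12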
47 | @abc.abstractmethod | |
48 | def get_max_account_data_stream_id(self): | |
92 | def get_max_account_data_stream_id(self) -> int: | |
49 | 93 | """Get the current max stream ID for account data stream |
50 | 94 | |
51 | 95 | Returns: |
52 | 96 | int |
53 | 97 | """ |
54 | raise NotImplementedError() | |
98 | return self._account_data_id_gen.get_current_token() | |
55 | 99 | |
56 | 100 | @cached() |
57 | 101 | async def get_account_data_for_user( |
286 | 330 | "get_updated_account_data_for_user", get_updated_account_data_for_user_txn |
287 | 331 | ) |
288 | 332 | |
289 | @cached(num_args=2, cache_context=True, max_entries=5000) | |
290 | async def is_ignored_by( | |
291 | self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext | |
292 | ) -> bool: | |
293 | ignored_account_data = await self.get_global_account_data_by_type_for_user( | |
294 | AccountDataTypes.IGNORED_USER_LIST, | |
295 | ignorer_user_id, | |
296 | on_invalidate=cache_context.invalidate, | |
297 | ) | |
298 | if not ignored_account_data: | |
299 | return False | |
300 | ||
301 | try: | |
302 | return ignored_user_id in ignored_account_data.get("ignored_users", {}) | |
303 | except TypeError: | |
304 | # The type of the ignored_users field is invalid. | |
305 | return False | |
306 | ||
307 | ||
308 | class AccountDataStore(AccountDataWorkerStore): | |
309 | def __init__(self, database: DatabasePool, db_conn, hs): | |
310 | self._account_data_id_gen = StreamIdGenerator( | |
311 | db_conn, | |
312 | "account_data_max_stream_id", | |
313 | "stream_id", | |
314 | extra_tables=[ | |
315 | ("room_account_data", "stream_id"), | |
316 | ("room_tags_revisions", "stream_id"), | |
317 | ], | |
318 | ) | |
319 | ||
320 | super().__init__(database, db_conn, hs) | |
321 | ||
322 | def get_max_account_data_stream_id(self) -> int: | |
323 | """Get the current max stream id for the private user data stream | |
324 | ||
325 | Returns: | |
326 | The maximum stream ID. | |
327 | """ | |
328 | return self._account_data_id_gen.get_current_token() | |
333 | @cached(max_entries=5000, iterable=True) | |
334 | async def ignored_by(self, user_id: str) -> Set[str]: | |
335 | """ | |
336 | Get users which ignore the given user. | |
337 | ||
338 | Args: |
339 | user_id: The user ID which might be ignored. | |
340 | ||
341 | Returns: |
342 | The user IDs which ignore the given user. | |
343 | """ | |
344 | return set( | |
345 | await self.db_pool.simple_select_onecol( | |
346 | table="ignored_users", | |
347 | keyvalues={"ignored_user_id": user_id}, | |
348 | retcol="ignorer_user_id", | |
349 | desc="ignored_by", | |
350 | ) | |
351 | ) | |
352 | ||
353 | def process_replication_rows(self, stream_name, instance_name, token, rows): | |
354 | if stream_name == TagAccountDataStream.NAME: | |
355 | self._account_data_id_gen.advance(instance_name, token) | |
356 | for row in rows: | |
357 | self.get_tags_for_user.invalidate((row.user_id,)) | |
358 | self._account_data_stream_cache.entity_has_changed(row.user_id, token) | |
359 | elif stream_name == AccountDataStream.NAME: | |
360 | self._account_data_id_gen.advance(instance_name, token) | |
361 | for row in rows: | |
362 | if not row.room_id: | |
363 | self.get_global_account_data_by_type_for_user.invalidate( | |
364 | (row.data_type, row.user_id) | |
365 | ) | |
366 | self.get_account_data_for_user.invalidate((row.user_id,)) | |
367 | self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) | |
368 | self.get_account_data_for_room_and_type.invalidate( | |
369 | (row.user_id, row.room_id, row.data_type) | |
370 | ) | |
371 | self._account_data_stream_cache.entity_has_changed(row.user_id, token) | |
372 | return super().process_replication_rows(stream_name, instance_name, token, rows) | |
329 | 373 | |
330 | 374 | async def add_account_data_to_room( |
331 | 375 | self, user_id: str, room_id: str, account_data_type: str, content: JsonDict |
341 | 385 | Returns: |
342 | 386 | The maximum stream ID. |
343 | 387 | """ |
388 | assert self._can_write_to_account_data | |
389 | ||
344 | 390 | content_json = json_encoder.encode(content) |
345 | 391 | |
346 | 392 | async with self._account_data_id_gen.get_next() as next_id: |
359 | 405 | lock=False, |
360 | 406 | ) |
361 | 407 | |
362 | # it's theoretically possible for the above to succeed and the | |
363 | # below to fail - in which case we might reuse a stream id on | |
364 | # restart, and the above update might not get propagated. That | |
365 | # doesn't sound any worse than the whole update getting lost, | |
366 | # which is what would happen if we combined the two into one | |
367 | # transaction. | |
368 | await self._update_max_stream_id(next_id) | |
369 | ||
370 | 408 | self._account_data_stream_cache.entity_has_changed(user_id, next_id) |
371 | 409 | self.get_account_data_for_user.invalidate((user_id,)) |
372 | 410 | self.get_account_data_for_room.invalidate((user_id, room_id)) |
389 | 427 | Returns: |
390 | 428 | The maximum stream ID. |
391 | 429 | """ |
392 | content_json = json_encoder.encode(content) | |
430 | assert self._can_write_to_account_data | |
393 | 431 | |
394 | 432 | async with self._account_data_id_gen.get_next() as next_id: |
395 | # no need to lock here as account_data has a unique constraint on | |
396 | # (user_id, account_data_type) so simple_upsert will retry if | |
397 | # there is a conflict. | |
398 | await self.db_pool.simple_upsert( | |
399 | desc="add_user_account_data", | |
400 | table="account_data", | |
401 | keyvalues={"user_id": user_id, "account_data_type": account_data_type}, | |
402 | values={"stream_id": next_id, "content": content_json}, | |
403 | lock=False, | |
404 | ) | |
405 | ||
406 | # it's theoretically possible for the above to succeed and the | |
407 | # below to fail - in which case we might reuse a stream id on | |
408 | # restart, and the above update might not get propagated. That | |
409 | # doesn't sound any worse than the whole update getting lost, | |
410 | # which is what would happen if we combined the two into one | |
411 | # transaction. | |
412 | # | |
413 | # Note: This is only here for backwards compat to allow admins to | |
414 | # roll back to a previous Synapse version. Next time we update the | |
415 | # database version we can remove this table. | |
416 | await self._update_max_stream_id(next_id) | |
433 | await self.db_pool.runInteraction( | |
434 | "add_user_account_data", | |
435 | self._add_account_data_for_user, | |
436 | next_id, | |
437 | user_id, | |
438 | account_data_type, | |
439 | content, | |
440 | ) | |
417 | 441 | |
418 | 442 | self._account_data_stream_cache.entity_has_changed(user_id, next_id) |
419 | 443 | self.get_account_data_for_user.invalidate((user_id,)) |
423 | 447 | |
424 | 448 | return self._account_data_id_gen.get_current_token() |
425 | 449 | |
426 | async def _update_max_stream_id(self, next_id: int) -> None: | |
427 | """Update the max stream_id | |
428 | ||
429 | Args: | |
430 | next_id: The revision to advance to. |
431 | """ | |
432 | ||
433 | # Note: This is only here for backwards compat to allow admins to | |
434 | # roll back to a previous Synapse version. Next time we update the | |
435 | # database version we can remove this table. | |
436 | ||
437 | def _update(txn): | |
438 | update_max_id_sql = ( | |
439 | "UPDATE account_data_max_stream_id" | |
440 | " SET stream_id = ?" | |
441 | " WHERE stream_id < ?" | |
442 | ) | |
443 | txn.execute(update_max_id_sql, (next_id, next_id)) | |
444 | ||
445 | await self.db_pool.runInteraction("update_account_data_max_stream_id", _update) | |
450 | def _add_account_data_for_user( | |
451 | self, | |
452 | txn, | |
453 | next_id: int, | |
454 | user_id: str, | |
455 | account_data_type: str, | |
456 | content: JsonDict, | |
457 | ) -> None: | |
458 | content_json = json_encoder.encode(content) | |
459 | ||
460 | # no need to lock here as account_data has a unique constraint on | |
461 | # (user_id, account_data_type) so simple_upsert will retry if | |
462 | # there is a conflict. | |
463 | self.db_pool.simple_upsert_txn( | |
464 | txn, | |
465 | table="account_data", | |
466 | keyvalues={"user_id": user_id, "account_data_type": account_data_type}, | |
467 | values={"stream_id": next_id, "content": content_json}, | |
468 | lock=False, | |
469 | ) | |
470 | ||
471 | # Ignored users get denormalized into a separate table as an optimisation. | |
472 | if account_data_type != AccountDataTypes.IGNORED_USER_LIST: | |
473 | return | |
474 | ||
475 | # Insert / delete to sync the list of ignored users. | |
476 | previously_ignored_users = set( | |
477 | self.db_pool.simple_select_onecol_txn( | |
478 | txn, | |
479 | table="ignored_users", | |
480 | keyvalues={"ignorer_user_id": user_id}, | |
481 | retcol="ignored_user_id", | |
482 | ) | |
483 | ) | |
484 | ||
485 | # If the data is invalid, no one is ignored. | |
486 | ignored_users_content = content.get("ignored_users", {}) | |
487 | if isinstance(ignored_users_content, dict): | |
488 | currently_ignored_users = set(ignored_users_content) | |
489 | else: | |
490 | currently_ignored_users = set() | |
491 | ||
492 | # Delete entries which are no longer ignored. | |
493 | self.db_pool.simple_delete_many_txn( | |
494 | txn, | |
495 | table="ignored_users", | |
496 | column="ignored_user_id", | |
497 | iterable=previously_ignored_users - currently_ignored_users, | |
498 | keyvalues={"ignorer_user_id": user_id}, | |
499 | ) | |
500 | ||
501 | # Add entries which are newly ignored. | |
502 | self.db_pool.simple_insert_many_txn( | |
503 | txn, | |
504 | table="ignored_users", | |
505 | values=[ | |
506 | {"ignorer_user_id": user_id, "ignored_user_id": u} | |
507 | for u in currently_ignored_users - previously_ignored_users | |
508 | ], | |
509 | ) | |
510 | ||
511 | # Invalidate the cache for any ignored users which were added or removed. | |
512 | for ignored_user_id in previously_ignored_users ^ currently_ignored_users: | |
513 | self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) | |
514 | ||
515 | ||
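
A worked example of the set arithmetic that `_add_account_data_for_user` uses above to sync the `ignored_users` table (illustrative user IDs):

previously_ignored = {"@a:example.com", "@b:example.com"}
currently_ignored = {"@b:example.com", "@c:example.com"}

# Rows to delete from ignored_users:
assert previously_ignored - currently_ignored == {"@a:example.com"}
# Rows to insert:
assert currently_ignored - previously_ignored == {"@c:example.com"}
# Caches to invalidate (symmetric difference: everyone added or removed):
assert previously_ignored ^ currently_ignored == {"@a:example.com", "@c:example.com"}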
516 | class AccountDataStore(AccountDataWorkerStore): | |
517 | pass |
406 | 406 | "_prune_old_user_ips", _prune_old_user_ips_txn |
407 | 407 | ) |
408 | 408 | |
409 | async def get_last_client_ip_by_device( | |
410 | self, user_id: str, device_id: Optional[str] | |
411 | ) -> Dict[Tuple[str, str], dict]: | |
412 | """For each device_id listed, give the user_ip it was last seen on. | |
413 | ||
414 | The result might be slightly out of date as client IPs are inserted in batches. | |
415 | ||
416 | Args: | |
417 | user_id: The user to fetch devices for. | |
418 | device_id: If None fetches all devices for the user | |
419 | ||
420 | Returns: | |
421 | A dictionary mapping a tuple of (user_id, device_id) to dicts, with | |
422 | keys giving the column names from the devices table. | |
423 | """ | |
424 | ||
425 | keyvalues = {"user_id": user_id} | |
426 | if device_id is not None: | |
427 | keyvalues["device_id"] = device_id | |
428 | ||
429 | res = await self.db_pool.simple_select_list( | |
430 | table="devices", | |
431 | keyvalues=keyvalues, | |
432 | retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |
433 | ) | |
434 | ||
435 | return {(d["user_id"], d["device_id"]): d for d in res} | |
436 | ||
409 | 437 | |
410 | 438 | class ClientIpStore(ClientIpWorkerStore): |
411 | 439 | def __init__(self, database: DatabasePool, db_conn, hs): |
469 | 497 | for entry in to_update.items(): |
470 | 498 | (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry |
471 | 499 | |
472 | try: | |
473 | self.db_pool.simple_upsert_txn( | |
500 | self.db_pool.simple_upsert_txn( | |
501 | txn, | |
502 | table="user_ips", | |
503 | keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip}, | |
504 | values={ | |
505 | "user_agent": user_agent, | |
506 | "device_id": device_id, | |
507 | "last_seen": last_seen, | |
508 | }, | |
509 | lock=False, | |
510 | ) | |
511 | ||
512 | # Technically an access token might not be associated with | |
513 | # a device so we need to check. | |
514 | if device_id: | |
515 | # this is always an update rather than an upsert: the row should | |
516 | # already exist, and if it doesn't, that may be because it has been | |
517 | # deleted, and we don't want to re-create it. | |
518 | self.db_pool.simple_update_txn( | |
474 | 519 | txn, |
475 | table="user_ips", | |
476 | keyvalues={ | |
477 | "user_id": user_id, | |
478 | "access_token": access_token, | |
520 | table="devices", | |
521 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
522 | updatevalues={ | |
523 | "user_agent": user_agent, | |
524 | "last_seen": last_seen, | |
479 | 525 | "ip": ip, |
480 | 526 | }, |
481 | values={ | |
482 | "user_agent": user_agent, | |
483 | "device_id": device_id, | |
484 | "last_seen": last_seen, | |
485 | }, | |
486 | lock=False, | |
487 | 527 | ) |
488 | ||
489 | # Technically an access token might not be associated with | |
490 | # a device so we need to check. | |
491 | if device_id: | |
492 | # this is always an update rather than an upsert: the row should | |
493 | # already exist, and if it doesn't, that may be because it has been | |
494 | # deleted, and we don't want to re-create it. | |
495 | self.db_pool.simple_update_txn( | |
496 | txn, | |
497 | table="devices", | |
498 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
499 | updatevalues={ | |
500 | "user_agent": user_agent, | |
501 | "last_seen": last_seen, | |
502 | "ip": ip, | |
503 | }, | |
504 | ) | |
505 | except Exception as e: | |
506 | # Failed to upsert, log and continue | |
507 | logger.error("Failed to insert client IP %r: %r", entry, e) | |
508 | 528 | |
509 | 529 | async def get_last_client_ip_by_device( |
510 | 530 | self, user_id: str, device_id: Optional[str] |
519 | 539 | A dictionary mapping a tuple of (user_id, device_id) to dicts, with |
520 | 540 | keys giving the column names from the devices table. |
521 | 541 | """ |
522 | ||
523 | keyvalues = {"user_id": user_id} | |
524 | if device_id is not None: | |
525 | keyvalues["device_id"] = device_id | |
526 | ||
527 | res = await self.db_pool.simple_select_list( | |
528 | table="devices", | |
529 | keyvalues=keyvalues, | |
530 | retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |
531 | ) | |
532 | ||
533 | ret = {(d["user_id"], d["device_id"]): d for d in res} | |
542 | ret = await super().get_last_client_ip_by_device(user_id, device_id) | |
543 | ||
544 | # Update what is retrieved from the database with data which is pending insertion. | |
534 | 545 | for key in self._batch_row_update: |
535 | 546 | uid, access_token, ip = key |
536 | 547 | if uid == user_id: |
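
The loop above (its body is elided by the diff hunk) implements a small write-behind merge: rows still sitting in the in-memory `_batch_row_update` are overlaid on what the database returned, so callers can see client IPs that have not been flushed yet. A toy version of the pattern, with assumed values:

db_rows = {("@u:example.com", "DEV"): {"ip": "9.9.9.9", "last_seen": 500}}
pending = {("@u:example.com", "token", "1.2.3.4"): ("agent", "DEV", 1000)}

for (uid, _access_token, ip), (user_agent, device_id, last_seen) in pending.items():
    if uid == "@u:example.com" and device_id:
        # Pending batch entries win over (possibly stale) database rows.
        db_rows[(uid, device_id)] = {
            "user_agent": user_agent,
            "ip": ip,
            "last_seen": last_seen,
        }

assert db_rows[("@u:example.com", "DEV")]["ip"] == "1.2.3.4"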
16 | 16 | from typing import List, Tuple |
17 | 17 | |
18 | 18 | from synapse.logging.opentracing import log_kv, set_tag, trace |
19 | from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |
19 | from synapse.replication.tcp.streams import ToDeviceStream | |
20 | from synapse.storage._base import SQLBaseStore, db_to_json | |
20 | 21 | from synapse.storage.database import DatabasePool |
22 | from synapse.storage.engines import PostgresEngine | |
23 | from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator | |
21 | 24 | from synapse.util import json_encoder |
22 | 25 | from synapse.util.caches.expiringcache import ExpiringCache |
26 | from synapse.util.caches.stream_change_cache import StreamChangeCache | |
23 | 27 | |
24 | 28 | logger = logging.getLogger(__name__) |
25 | 29 | |
26 | 30 | |
27 | 31 | class DeviceInboxWorkerStore(SQLBaseStore): |
32 | def __init__(self, database: DatabasePool, db_conn, hs): | |
33 | super().__init__(database, db_conn, hs) | |
34 | ||
35 | self._instance_name = hs.get_instance_name() | |
36 | ||
37 | # Map of (user_id, device_id) to the last stream_id that has been | |
38 | # deleted up to. This is so that we can no-op deletions. |
39 | self._last_device_delete_cache = ExpiringCache( | |
40 | cache_name="last_device_delete_cache", | |
41 | clock=self._clock, | |
42 | max_len=10000, | |
43 | expiry_ms=30 * 60 * 1000, | |
44 | ) | |
45 | ||
46 | if isinstance(database.engine, PostgresEngine): | |
47 | self._can_write_to_device = ( | |
48 | self._instance_name in hs.config.worker.writers.to_device | |
49 | ) | |
50 | ||
51 | self._device_inbox_id_gen = MultiWriterIdGenerator( | |
52 | db_conn=db_conn, | |
53 | db=database, | |
54 | stream_name="to_device", | |
55 | instance_name=self._instance_name, | |
56 | tables=[("device_inbox", "instance_name", "stream_id")], | |
57 | sequence_name="device_inbox_sequence", | |
58 | writers=hs.config.worker.writers.to_device, | |
59 | ) | |
60 | else: | |
61 | self._can_write_to_device = True | |
62 | self._device_inbox_id_gen = StreamIdGenerator( | |
63 | db_conn, "device_inbox", "stream_id" | |
64 | ) | |
65 | ||
66 | max_device_inbox_id = self._device_inbox_id_gen.get_current_token() | |
67 | device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( | |
68 | db_conn, | |
69 | "device_inbox", | |
70 | entity_column="user_id", | |
71 | stream_column="stream_id", | |
72 | max_value=max_device_inbox_id, | |
73 | limit=1000, | |
74 | ) | |
75 | self._device_inbox_stream_cache = StreamChangeCache( | |
76 | "DeviceInboxStreamChangeCache", | |
77 | min_device_inbox_id, | |
78 | prefilled_cache=device_inbox_prefill, | |
79 | ) | |
80 | ||
81 | # The federation outbox and the local device inbox use the same |
82 | # stream_id generator. | |
83 | device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( | |
84 | db_conn, | |
85 | "device_federation_outbox", | |
86 | entity_column="destination", | |
87 | stream_column="stream_id", | |
88 | max_value=max_device_inbox_id, | |
89 | limit=1000, | |
90 | ) | |
91 | self._device_federation_outbox_stream_cache = StreamChangeCache( | |
92 | "DeviceFederationOutboxStreamChangeCache", | |
93 | min_device_outbox_id, | |
94 | prefilled_cache=device_outbox_prefill, | |
95 | ) | |
96 | ||
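
The prefilled caches built above are what let the sync path skip the database; a small usage sketch, assuming the usual `StreamChangeCache` semantics:

from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("Example", 10, prefilled_cache={"@u:example.com": 12})
cache.entity_has_changed("@u:example.com", 15)

# "Has anything for this entity changed since stream position X?"
assert cache.has_entity_changed("@u:example.com", 14)
assert not cache.has_entity_changed("@u:example.com", 15)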
97 | def process_replication_rows(self, stream_name, instance_name, token, rows): | |
98 | if stream_name == ToDeviceStream.NAME: | |
99 | self._device_inbox_id_gen.advance(instance_name, token) | |
100 | for row in rows: | |
101 | if row.entity.startswith("@"): | |
102 | self._device_inbox_stream_cache.entity_has_changed( | |
103 | row.entity, token | |
104 | ) | |
105 | else: | |
106 | self._device_federation_outbox_stream_cache.entity_has_changed( | |
107 | row.entity, token | |
108 | ) | |
109 | return super().process_replication_rows(stream_name, instance_name, token, rows) | |
110 | ||
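
The branch above relies on the Matrix identifier grammar: user IDs always begin with `@`, while federation destinations are bare server names. For example:

def targets_local_inbox(entity: str) -> bool:
    # "@alice:example.com" -> local device inbox;
    # "matrix.org"         -> federation outbox.
    return entity.startswith("@")

assert targets_local_inbox("@alice:example.com")
assert not targets_local_inbox("matrix.org")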
28 | 111 | def get_to_device_stream_token(self): |
29 | 112 | return self._device_inbox_id_gen.get_current_token() |
30 | 113 | |
277 | 360 | "get_all_new_device_messages", get_all_new_device_messages_txn |
278 | 361 | ) |
279 | 362 | |
280 | ||
281 | class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |
282 | DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | |
283 | ||
284 | def __init__(self, database: DatabasePool, db_conn, hs): | |
285 | super().__init__(database, db_conn, hs) | |
286 | ||
287 | self.db_pool.updates.register_background_index_update( | |
288 | "device_inbox_stream_index", | |
289 | index_name="device_inbox_stream_id_user_id", | |
290 | table="device_inbox", | |
291 | columns=["stream_id", "user_id"], | |
292 | ) | |
293 | ||
294 | self.db_pool.updates.register_background_update_handler( | |
295 | self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox | |
296 | ) | |
297 | ||
298 | async def _background_drop_index_device_inbox(self, progress, batch_size): | |
299 | def reindex_txn(conn): | |
300 | txn = conn.cursor() | |
301 | txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") | |
302 | txn.close() | |
303 | ||
304 | await self.db_pool.runWithConnection(reindex_txn) | |
305 | ||
306 | await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) | |
307 | ||
308 | return 1 | |
309 | ||
310 | ||
311 | class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): | |
312 | DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | |
313 | ||
314 | def __init__(self, database: DatabasePool, db_conn, hs): | |
315 | super().__init__(database, db_conn, hs) | |
316 | ||
317 | # Map of (user_id, device_id) to the last stream_id that has been | |
318 | # deleted up to. This is so that we can no-op deletions. |
319 | self._last_device_delete_cache = ExpiringCache( | |
320 | cache_name="last_device_delete_cache", | |
321 | clock=self._clock, | |
322 | max_len=10000, | |
323 | expiry_ms=30 * 60 * 1000, | |
324 | ) | |
325 | ||
326 | 363 | @trace |
327 | 364 | async def add_messages_to_device_inbox( |
328 | 365 | self, |
341 | 378 | The new stream_id. |
342 | 379 | """ |
343 | 380 | |
381 | assert self._can_write_to_device | |
382 | ||
344 | 383 | def add_messages_txn(txn, now_ms, stream_id): |
345 | 384 | # Add the local messages directly to the local inbox. |
346 | 385 | self._add_messages_to_local_device_inbox_txn( |
350 | 389 | # Add the remote messages to the federation outbox. |
351 | 390 | # We'll send them to a remote server when we next send a |
352 | 391 | # federation transaction to that destination. |
353 | sql = ( | |
354 | "INSERT INTO device_federation_outbox" | |
355 | " (destination, stream_id, queued_ts, messages_json)" | |
356 | " VALUES (?,?,?,?)" | |
357 | ) | |
358 | rows = [] | |
359 | for destination, edu in remote_messages_by_destination.items(): | |
360 | edu_json = json_encoder.encode(edu) | |
361 | rows.append((destination, stream_id, now_ms, edu_json)) | |
362 | txn.executemany(sql, rows) | |
392 | self.db_pool.simple_insert_many_txn( | |
393 | txn, | |
394 | table="device_federation_outbox", | |
395 | values=[ | |
396 | { | |
397 | "destination": destination, | |
398 | "stream_id": stream_id, | |
399 | "queued_ts": now_ms, | |
400 | "messages_json": json_encoder.encode(edu), | |
401 | "instance_name": self._instance_name, | |
402 | } | |
403 | for destination, edu in remote_messages_by_destination.items() | |
404 | ], | |
405 | ) | |
363 | 406 | |
364 | 407 | async with self._device_inbox_id_gen.get_next() as stream_id: |
365 | 408 | now_ms = self.clock.time_msec() |
378 | 421 | async def add_messages_from_remote_to_device_inbox( |
379 | 422 | self, origin: str, message_id: str, local_messages_by_user_then_device: dict |
380 | 423 | ) -> int: |
424 | assert self._can_write_to_device | |
425 | ||
381 | 426 | def add_messages_txn(txn, now_ms, stream_id): |
382 | 427 | # Check if we've already inserted a matching message_id for that |
383 | 428 | # origin. This can happen if the origin doesn't receive our |
426 | 471 | def _add_messages_to_local_device_inbox_txn( |
427 | 472 | self, txn, stream_id, messages_by_user_then_device |
428 | 473 | ): |
474 | assert self._can_write_to_device | |
475 | ||
429 | 476 | local_by_user_then_device = {} |
430 | 477 | for user_id, messages_by_device in messages_by_user_then_device.items(): |
431 | 478 | messages_json_for_user = {} |
432 | 479 | devices = list(messages_by_device.keys()) |
433 | 480 | if len(devices) == 1 and devices[0] == "*": |
434 | 481 | # Handle wildcard device_ids. |
435 | sql = "SELECT device_id FROM devices WHERE user_id = ?" | |
436 | txn.execute(sql, (user_id,)) | |
482 | devices = self.db_pool.simple_select_onecol_txn( | |
483 | txn, | |
484 | table="devices", | |
485 | keyvalues={"user_id": user_id}, | |
486 | retcol="device_id", | |
487 | ) | |
488 | ||
437 | 489 | message_json = json_encoder.encode(messages_by_device["*"]) |
438 | for row in txn: | |
490 | for device_id in devices: | |
439 | 491 | # Add the message for all devices for this user on this |
440 | 492 | # server. |
441 | device = row[0] | |
442 | messages_json_for_user[device] = message_json | |
493 | messages_json_for_user[device_id] = message_json | |
443 | 494 | else: |
444 | 495 | if not devices: |
445 | 496 | continue |
446 | 497 | |
447 | clause, args = make_in_list_sql_clause( | |
448 | txn.database_engine, "device_id", devices | |
498 | rows = self.db_pool.simple_select_many_txn( | |
499 | txn, | |
500 | table="devices", | |
501 | keyvalues={"user_id": user_id}, | |
502 | column="device_id", | |
503 | iterable=devices, | |
504 | retcols=("device_id",), | |
449 | 505 | ) |
450 | sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause | |
451 | ||
452 | # TODO: Maybe this needs to be done in batches if there are | |
453 | # too many local devices for a given user. | |
454 | txn.execute(sql, [user_id] + list(args)) | |
455 | for row in txn: | |
506 | ||
507 | for row in rows: | |
456 | 508 | # Only insert into the local inbox if the device exists on |
457 | 509 | # this server |
458 | device = row[0] | |
459 | message_json = json_encoder.encode(messages_by_device[device]) | |
460 | messages_json_for_user[device] = message_json | |
510 | device_id = row["device_id"] | |
511 | message_json = json_encoder.encode(messages_by_device[device_id]) | |
512 | messages_json_for_user[device_id] = message_json | |
461 | 513 | |
462 | 514 | if messages_json_for_user: |
463 | 515 | local_by_user_then_device[user_id] = messages_json_for_user |
465 | 517 | if not local_by_user_then_device: |
466 | 518 | return |
467 | 519 | |
468 | sql = ( | |
469 | "INSERT INTO device_inbox" | |
470 | " (user_id, device_id, stream_id, message_json)" | |
471 | " VALUES (?,?,?,?)" | |
472 | ) | |
473 | rows = [] | |
474 | for user_id, messages_by_device in local_by_user_then_device.items(): | |
475 | for device_id, message_json in messages_by_device.items(): | |
476 | rows.append((user_id, device_id, stream_id, message_json)) | |
477 | ||
478 | txn.executemany(sql, rows) | |
520 | self.db_pool.simple_insert_many_txn( | |
521 | txn, | |
522 | table="device_inbox", | |
523 | values=[ | |
524 | { | |
525 | "user_id": user_id, | |
526 | "device_id": device_id, | |
527 | "stream_id": stream_id, | |
528 | "message_json": message_json, | |
529 | "instance_name": self._instance_name, | |
530 | } | |
531 | for user_id, messages_by_device in local_by_user_then_device.items() | |
532 | for device_id, message_json in messages_by_device.items() | |
533 | ], | |
534 | ) | |
535 | ||
536 | ||
537 | class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |
538 | DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | |
539 | ||
540 | def __init__(self, database: DatabasePool, db_conn, hs): | |
541 | super().__init__(database, db_conn, hs) | |
542 | ||
543 | self.db_pool.updates.register_background_index_update( | |
544 | "device_inbox_stream_index", | |
545 | index_name="device_inbox_stream_id_user_id", | |
546 | table="device_inbox", | |
547 | columns=["stream_id", "user_id"], | |
548 | ) | |
549 | ||
550 | self.db_pool.updates.register_background_update_handler( | |
551 | self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox | |
552 | ) | |
553 | ||
554 | async def _background_drop_index_device_inbox(self, progress, batch_size): | |
555 | def reindex_txn(conn): | |
556 | txn = conn.cursor() | |
557 | txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") | |
558 | txn.close() | |
559 | ||
560 | await self.db_pool.runWithConnection(reindex_txn) | |
561 | ||
562 | await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) | |
563 | ||
564 | return 1 | |
565 | ||
566 | ||
567 | class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): | |
568 | pass |
24 | 24 | from synapse.logging.opentracing import log_kv, set_tag, trace |
25 | 25 | from synapse.storage._base import SQLBaseStore, db_to_json |
26 | 26 | from synapse.storage.database import DatabasePool, make_in_list_sql_clause |
27 | from synapse.storage.engines import PostgresEngine | |
27 | 28 | from synapse.storage.types import Cursor |
28 | 29 | from synapse.types import JsonDict |
29 | 30 | from synapse.util import json_encoder |
512 | 513 | |
513 | 514 | for user_chunk in batch_iter(user_ids, 100): |
514 | 515 | clause, params = make_in_list_sql_clause( |
515 | txn.database_engine, "k.user_id", user_chunk | |
516 | ) | |
517 | sql = ( | |
518 | """ | |
519 | SELECT k.user_id, k.keytype, k.keydata, k.stream_id | |
520 | FROM e2e_cross_signing_keys k | |
521 | INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id | |
522 | FROM e2e_cross_signing_keys | |
523 | GROUP BY user_id, keytype) s | |
524 | USING (user_id, stream_id, keytype) | |
525 | WHERE | |
526 | """ | |
527 | + clause | |
528 | ) | |
516 | txn.database_engine, "user_id", user_chunk | |
517 | ) | |
518 | ||
519 | # Fetch the latest key for each type per user. | |
520 | if isinstance(self.database_engine, PostgresEngine): | |
521 | # The `DISTINCT ON` clause will pick the *first* row it | |
522 | # encounters, so ordering by stream ID desc will ensure we get | |
523 | # the latest key. | |
524 | sql = """ | |
525 | SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id | |
526 | FROM e2e_cross_signing_keys | |
527 | WHERE %(clause)s | |
528 | ORDER BY user_id, keytype, stream_id DESC | |
529 | """ % { | |
530 | "clause": clause | |
531 | } | |
532 | else: | |
533 | # SQLite has special handling for bare columns when using | |
534 | # MIN/MAX with a `GROUP BY` clause where it picks the value from | |
535 | # a row that matches the MIN/MAX. | |
536 | sql = """ | |
537 | SELECT user_id, keytype, keydata, MAX(stream_id) | |
538 | FROM e2e_cross_signing_keys | |
539 | WHERE %(clause)s | |
540 | GROUP BY user_id, keytype | |
541 | """ % { | |
542 | "clause": clause | |
543 | } | |
529 | 544 | |
530 | 545 | txn.execute(sql, params) |
531 | 546 | rows = self.db_pool.cursor_to_dict(txn) |
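
Both branches fetch the newest key per `(user_id, keytype)`: PostgreSQL via `DISTINCT ON ... ORDER BY stream_id DESC`, SQLite via its documented guarantee that, with a single MIN/MAX aggregate, bare columns are taken from the row holding the min/max. A quick self-contained demonstration of the SQLite behaviour (throwaway schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE keys (user_id TEXT, keytype TEXT, keydata TEXT, stream_id INT)"
)
conn.executemany(
    "INSERT INTO keys VALUES (?, ?, ?, ?)",
    [("@u:example.com", "master", "old", 1), ("@u:example.com", "master", "new", 2)],
)
row = conn.execute(
    "SELECT user_id, keytype, keydata, MAX(stream_id)"
    " FROM keys GROUP BY user_id, keytype"
).fetchone()
# keydata is taken from the row holding the maximum stream_id.
assert row == ("@u:example.com", "master", "new", 2)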
705 | 720 | def get_device_stream_token(self) -> int: |
706 | 721 | """Get the current stream id from the _device_list_id_gen""" |
707 | 722 | ... |
708 | ||
709 | ||
710 | class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |
711 | async def set_e2e_device_keys( | |
712 | self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict | |
713 | ) -> bool: | |
714 | """Stores device keys for a device. Returns whether there was a change | |
715 | or the keys were already in the database. | |
716 | """ | |
717 | ||
718 | def _set_e2e_device_keys_txn(txn): | |
719 | set_tag("user_id", user_id) | |
720 | set_tag("device_id", device_id) | |
721 | set_tag("time_now", time_now) | |
722 | set_tag("device_keys", device_keys) | |
723 | ||
724 | old_key_json = self.db_pool.simple_select_one_onecol_txn( | |
725 | txn, | |
726 | table="e2e_device_keys_json", | |
727 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
728 | retcol="key_json", | |
729 | allow_none=True, | |
730 | ) | |
731 | ||
732 | # In py3 we need old_key_json to match new_key_json type. The DB | |
733 | # returns unicode while encode_canonical_json returns bytes. | |
734 | new_key_json = encode_canonical_json(device_keys).decode("utf-8") | |
735 | ||
736 | if old_key_json == new_key_json: | |
737 | log_kv({"Message": "Device key already stored."}) | |
738 | return False | |
739 | ||
740 | self.db_pool.simple_upsert_txn( | |
741 | txn, | |
742 | table="e2e_device_keys_json", | |
743 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
744 | values={"ts_added_ms": time_now, "key_json": new_key_json}, | |
745 | ) | |
746 | log_kv({"message": "Device keys stored."}) | |
747 | return True | |
748 | ||
749 | return await self.db_pool.runInteraction( | |
750 | "set_e2e_device_keys", _set_e2e_device_keys_txn | |
751 | ) | |
752 | 723 | |
753 | 724 | async def claim_e2e_one_time_keys( |
754 | 725 | self, query_list: Iterable[Tuple[str, str, str]] |
839 | 810 | "claim_e2e_one_time_keys", _claim_e2e_one_time_keys |
840 | 811 | ) |
841 | 812 | |
813 | ||
814 | class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |
815 | async def set_e2e_device_keys( | |
816 | self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict | |
817 | ) -> bool: | |
818 | """Stores device keys for a device. Returns whether there was a change | |
819 | or the keys were already in the database. | |
820 | """ | |
821 | ||
822 | def _set_e2e_device_keys_txn(txn): | |
823 | set_tag("user_id", user_id) | |
824 | set_tag("device_id", device_id) | |
825 | set_tag("time_now", time_now) | |
826 | set_tag("device_keys", device_keys) | |
827 | ||
828 | old_key_json = self.db_pool.simple_select_one_onecol_txn( | |
829 | txn, | |
830 | table="e2e_device_keys_json", | |
831 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
832 | retcol="key_json", | |
833 | allow_none=True, | |
834 | ) | |
835 | ||
836 | # In py3 we need old_key_json to match new_key_json type. The DB | |
837 | # returns unicode while encode_canonical_json returns bytes. | |
838 | new_key_json = encode_canonical_json(device_keys).decode("utf-8") | |
839 | ||
840 | if old_key_json == new_key_json: | |
841 | log_kv({"Message": "Device key already stored."}) | |
842 | return False | |
843 | ||
844 | self.db_pool.simple_upsert_txn( | |
845 | txn, | |
846 | table="e2e_device_keys_json", | |
847 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
848 | values={"ts_added_ms": time_now, "key_json": new_key_json}, | |
849 | ) | |
850 | log_kv({"message": "Device keys stored."}) | |
851 | return True | |
852 | ||
853 | return await self.db_pool.runInteraction( | |
854 | "set_e2e_device_keys", _set_e2e_device_keys_txn | |
855 | ) | |
856 | ||
842 | 857 | async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: |
843 | 858 | def delete_e2e_keys_by_device_txn(txn): |
844 | 859 | log_kv( |
23 | 23 | from synapse.storage.database import DatabasePool, LoggingTransaction |
24 | 24 | from synapse.storage.databases.main.events_worker import EventsWorkerStore |
25 | 25 | from synapse.storage.databases.main.signatures import SignatureWorkerStore |
26 | from synapse.storage.engines import PostgresEngine | |
27 | from synapse.storage.types import Cursor | |
26 | 28 | from synapse.types import Collection |
27 | 29 | from synapse.util.caches.descriptors import cached |
28 | 30 | from synapse.util.caches.lrucache import LruCache |
29 | 31 | from synapse.util.iterutils import batch_iter |
30 | 32 | |
31 | 33 | logger = logging.getLogger(__name__) |
34 | ||
35 | ||
36 | class _NoChainCoverIndex(Exception): | |
37 | def __init__(self, room_id: str): | |
38 | super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) | |
32 | 39 | |
33 | 40 | |
34 | 41 | class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): |
150 | 157 | The set of the difference in auth chains. |
151 | 158 | """ |
152 | 159 | |
160 | # Check if we have indexed the room so we can use the chain cover | |
161 | # algorithm. | |
162 | room = await self.get_room(room_id) | |
163 | if room["has_auth_chain_index"]: | |
164 | try: | |
165 | return await self.db_pool.runInteraction( | |
166 | "get_auth_chain_difference_chains", | |
167 | self._get_auth_chain_difference_using_cover_index_txn, | |
168 | room_id, | |
169 | state_sets, | |
170 | ) | |
171 | except _NoChainCoverIndex: | |
172 | # For whatever reason we don't actually have a chain cover index | |
173 | # for the events in question, so we fall back to the old method. | |
174 | pass | |
175 | ||
153 | 176 | return await self.db_pool.runInteraction( |
154 | 177 | "get_auth_chain_difference", |
155 | 178 | self._get_auth_chain_difference_txn, |
156 | 179 | state_sets, |
157 | 180 | ) |
158 | 181 | |
182 | def _get_auth_chain_difference_using_cover_index_txn( | |
183 | self, txn: Cursor, room_id: str, state_sets: List[Set[str]] | |
184 | ) -> Set[str]: | |
185 | """Calculates the auth chain difference using the chain index. | |
186 | ||
187 | See docs/auth_chain_difference_algorithm.md for details | |
188 | """ | |
189 | ||
190 | # First we look up the chain ID/sequence numbers for all the events, and | |
191 | # work out the chain/sequence numbers reachable from each state set. | |
192 | ||
193 | initial_events = set(state_sets[0]).union(*state_sets[1:]) | |
194 | ||
195 | # Map from event_id -> (chain ID, seq no) | |
196 | chain_info = {} # type: Dict[str, Tuple[int, int]] | |
197 | ||
198 | # Map from chain ID -> seq no -> event Id | |
199 | chain_to_event = {} # type: Dict[int, Dict[int, str]] | |
200 | ||
201 | # All the chains that we've found that are reachable from the state | |
202 | # sets. | |
203 | seen_chains = set() # type: Set[int] | |
204 | ||
205 | sql = """ | |
206 | SELECT event_id, chain_id, sequence_number | |
207 | FROM event_auth_chains | |
208 | WHERE %s | |
209 | """ | |
210 | for batch in batch_iter(initial_events, 1000): | |
211 | clause, args = make_in_list_sql_clause( | |
212 | txn.database_engine, "event_id", batch | |
213 | ) | |
214 | txn.execute(sql % (clause,), args) | |
215 | ||
216 | for event_id, chain_id, sequence_number in txn: | |
217 | chain_info[event_id] = (chain_id, sequence_number) | |
218 | seen_chains.add(chain_id) | |
219 | chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id | |
220 | ||
221 | # Check that we actually have a chain ID for all the events. | |
222 | events_missing_chain_info = initial_events.difference(chain_info) | |
223 | if events_missing_chain_info: | |
224 | # This can happen due to e.g. downgrade/upgrade of the server. We | |
225 | # raise an exception and fall back to the previous algorithm. | |
226 | logger.info( | |
227 | "Unexpectedly found that events don't have chain IDs in room %s: %s", | |
228 | room_id, | |
229 | events_missing_chain_info, | |
230 | ) | |
231 | raise _NoChainCoverIndex(room_id) | |
232 | ||
233 | # Corresponds to `state_sets`, except as a map from chain ID to max | |
234 | # sequence number reachable from the state set. | |
235 | set_to_chain = [] # type: List[Dict[int, int]] | |
236 | for state_set in state_sets: | |
237 | chains = {} # type: Dict[int, int] | |
238 | set_to_chain.append(chains) | |
239 | ||
240 | for event_id in state_set: | |
241 | chain_id, seq_no = chain_info[event_id] | |
242 | ||
243 | chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) | |
244 | ||
245 | # Now we look up all links for the chains we have, adding chains to | |
246 | # set_to_chain that are reachable from each set. | |
247 | sql = """ | |
248 | SELECT | |
249 | origin_chain_id, origin_sequence_number, | |
250 | target_chain_id, target_sequence_number | |
251 | FROM event_auth_chain_links | |
252 | WHERE %s | |
253 | """ | |
254 | ||
255 | # (We need to take a copy of `seen_chains` as we want to mutate it in | |
256 | # the loop) | |
257 | for batch in batch_iter(set(seen_chains), 1000): | |
258 | clause, args = make_in_list_sql_clause( | |
259 | txn.database_engine, "origin_chain_id", batch | |
260 | ) | |
261 | txn.execute(sql % (clause,), args) | |
262 | ||
263 | for ( | |
264 | origin_chain_id, | |
265 | origin_sequence_number, | |
266 | target_chain_id, | |
267 | target_sequence_number, | |
268 | ) in txn: | |
269 | for chains in set_to_chain: | |
270 | # chains are only reachable if the origin sequence number of |
271 | # the link is less than or equal to the max sequence number |
272 | # in the origin chain. |
273 | if origin_sequence_number <= chains.get(origin_chain_id, 0): | |
274 | chains[target_chain_id] = max( | |
275 | target_sequence_number, chains.get(target_chain_id, 0), | |
276 | ) | |
277 | ||
278 | seen_chains.add(target_chain_id) | |
279 | ||
280 | # Now for each chain we figure out the maximum sequence number reachable | |
281 | # from *any* state set and the minimum sequence number reachable from | |
282 | # *all* state sets. Events in that range are in the auth chain | |
283 | # difference. | |
284 | result = set() | |
285 | ||
286 | # Mapping from chain ID to the range of sequence numbers that should be | |
287 | # pulled from the database. | |
288 | chain_to_gap = {} # type: Dict[int, Tuple[int, int]] | |
289 | ||
290 | for chain_id in seen_chains: | |
291 | min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) | |
292 | max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain) | |
293 | ||
294 | if min_seq_no < max_seq_no: | |
295 | # We have a non-empty gap, try to fill it from the events that |
296 | # we have, otherwise add them to the list of gaps to pull out | |
297 | # from the DB. | |
298 | for seq_no in range(min_seq_no + 1, max_seq_no + 1): | |
299 | event_id = chain_to_event.get(chain_id, {}).get(seq_no) | |
300 | if event_id: | |
301 | result.add(event_id) | |
302 | else: | |
303 | chain_to_gap[chain_id] = (min_seq_no, max_seq_no) | |
304 | break | |
305 | ||
306 | if not chain_to_gap: | |
307 | # If there are no gaps to fetch, we're done! | |
308 | return result | |
309 | ||
310 | if isinstance(self.database_engine, PostgresEngine): | |
311 | # We can use `execute_values` to efficiently fetch the gaps when | |
312 | # using postgres. | |
313 | sql = """ | |
314 | SELECT event_id | |
315 | FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) | |
316 | WHERE | |
317 | c.chain_id = l.chain_id | |
318 | AND min_seq < sequence_number AND sequence_number <= max_seq | |
319 | """ | |
320 | ||
321 | args = [ | |
322 | (chain_id, min_no, max_no) | |
323 | for chain_id, (min_no, max_no) in chain_to_gap.items() | |
324 | ] | |
325 | ||
326 | rows = txn.execute_values(sql, args) | |
327 | result.update(r for r, in rows) | |
328 | else: | |
329 | # For SQLite we just fall back to doing a noddy for loop. | |
330 | sql = """ | |
331 | SELECT event_id FROM event_auth_chains | |
332 | WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? | |
333 | """ | |
334 | for chain_id, (min_no, max_no) in chain_to_gap.items(): | |
335 | txn.execute(sql, (chain_id, min_no, max_no)) | |
336 | result.update(r for r, in txn) | |
337 | ||
338 | return result | |
339 | ||
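
A worked example of the gap computation above, with toy chain numbers: if one state set reaches chain 1 up to sequence number 2 and another reaches it up to 5, then sequence numbers 3 to 5 on that chain are in the auth chain difference:

set_to_chain = [{1: 2}, {1: 5}]  # chain ID -> max seq reachable, per state set
chain_id = 1

min_seq = min(chains.get(chain_id, 0) for chains in set_to_chain)  # 2
max_seq = max(chains.get(chain_id, 0) for chains in set_to_chain)  # 5

# Events in the half-open range (min_seq, max_seq] are in the difference.
assert list(range(min_seq + 1, max_seq + 1)) == [3, 4, 5]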
159 | 340 | def _get_auth_chain_difference_txn( |
160 | 341 | self, txn, state_sets: List[Set[str]] |
161 | 342 | ) -> Set[str]: |
343 | """Calculates the auth chain difference using a breadth first search. | |
344 | ||
345 | This is used when we don't have a cover index for the room. | |
346 | """ | |
162 | 347 | |
163 | 348 | # Algorithm Description |
164 | 349 | # ~~~~~~~~~~~~~~~~~~~~~ |
834 | 834 | (rotate_to_stream_ordering,), |
835 | 835 | ) |
836 | 836 | |
837 | ||
838 | class EventPushActionsStore(EventPushActionsWorkerStore): | |
839 | EPA_HIGHLIGHT_INDEX = "epa_highlight_index" | |
840 | ||
841 | def __init__(self, database: DatabasePool, db_conn, hs): | |
842 | super().__init__(database, db_conn, hs) | |
843 | ||
844 | self.db_pool.updates.register_background_index_update( | |
845 | self.EPA_HIGHLIGHT_INDEX, | |
846 | index_name="event_push_actions_u_highlight", | |
847 | table="event_push_actions", | |
848 | columns=["user_id", "stream_ordering"], | |
849 | ) | |
850 | ||
851 | self.db_pool.updates.register_background_index_update( | |
852 | "event_push_actions_highlights_index", | |
853 | index_name="event_push_actions_highlights_index", | |
854 | table="event_push_actions", | |
855 | columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], | |
856 | where_clause="highlight=1", | |
857 | ) | |
858 | ||
859 | async def get_push_actions_for_user( | |
860 | self, user_id, before=None, limit=50, only_highlight=False | |
861 | ): | |
862 | def f(txn): | |
863 | before_clause = "" | |
864 | if before: | |
865 | before_clause = "AND epa.stream_ordering < ?" | |
866 | args = [user_id, before, limit] | |
867 | else: | |
868 | args = [user_id, limit] | |
869 | ||
870 | if only_highlight: | |
871 | if len(before_clause) > 0: | |
872 | before_clause += " " | |
873 | before_clause += "AND epa.highlight = 1" | |
874 | ||
875 | # NB. This assumes event_ids are globally unique since | |
876 | # it makes the query easier to index | |
877 | sql = ( | |
878 | "SELECT epa.event_id, epa.room_id," | |
879 | " epa.stream_ordering, epa.topological_ordering," | |
880 | " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" | |
881 | " FROM event_push_actions epa, events e" | |
882 | " WHERE epa.event_id = e.event_id" | |
883 | " AND epa.user_id = ? %s" | |
884 | " AND epa.notif = 1" | |
885 | " ORDER BY epa.stream_ordering DESC" | |
886 | " LIMIT ?" % (before_clause,) | |
887 | ) | |
888 | txn.execute(sql, args) | |
889 | return self.db_pool.cursor_to_dict(txn) | |
890 | ||
891 | push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) | |
892 | for pa in push_actions: | |
893 | pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) | |
894 | return push_actions | |
895 | ||
896 | 837 | def _remove_old_push_actions_before_txn( |
897 | 838 | self, txn, room_id, user_id, stream_ordering |
898 | 839 | ): |
940 | 881 | ) |
941 | 882 | |
942 | 883 | |
884 | class EventPushActionsStore(EventPushActionsWorkerStore): | |
885 | EPA_HIGHLIGHT_INDEX = "epa_highlight_index" | |
886 | ||
887 | def __init__(self, database: DatabasePool, db_conn, hs): | |
888 | super().__init__(database, db_conn, hs) | |
889 | ||
890 | self.db_pool.updates.register_background_index_update( | |
891 | self.EPA_HIGHLIGHT_INDEX, | |
892 | index_name="event_push_actions_u_highlight", | |
893 | table="event_push_actions", | |
894 | columns=["user_id", "stream_ordering"], | |
895 | ) | |
896 | ||
897 | self.db_pool.updates.register_background_index_update( | |
898 | "event_push_actions_highlights_index", | |
899 | index_name="event_push_actions_highlights_index", | |
900 | table="event_push_actions", | |
901 | columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], | |
902 | where_clause="highlight=1", | |
903 | ) | |
904 | ||
905 | async def get_push_actions_for_user( | |
906 | self, user_id, before=None, limit=50, only_highlight=False | |
907 | ): | |
908 | def f(txn): | |
909 | before_clause = "" | |
910 | if before: | |
911 | before_clause = "AND epa.stream_ordering < ?" | |
912 | args = [user_id, before, limit] | |
913 | else: | |
914 | args = [user_id, limit] | |
915 | ||
916 | if only_highlight: | |
917 | if len(before_clause) > 0: | |
918 | before_clause += " " | |
919 | before_clause += "AND epa.highlight = 1" | |
920 | ||
921 | # NB. This assumes event_ids are globally unique since | |
922 | # it makes the query easier to index | |
923 | sql = ( | |
924 | "SELECT epa.event_id, epa.room_id," | |
925 | " epa.stream_ordering, epa.topological_ordering," | |
926 | " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" | |
927 | " FROM event_push_actions epa, events e" | |
928 | " WHERE epa.event_id = e.event_id" | |
929 | " AND epa.user_id = ? %s" | |
930 | " AND epa.notif = 1" | |
931 | " ORDER BY epa.stream_ordering DESC" | |
932 | " LIMIT ?" % (before_clause,) | |
933 | ) | |
934 | txn.execute(sql, args) | |
935 | return self.db_pool.cursor_to_dict(txn) | |
936 | ||
937 | push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) | |
938 | for pa in push_actions: | |
939 | pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) | |
940 | return push_actions | |
941 | ||
942 | ||
943 | 943 | def _action_has_highlight(actions): |
944 | 944 | for action in actions: |
945 | 945 | try: |
16 | 16 | import itertools |
17 | 17 | import logging |
18 | 18 | from collections import OrderedDict, namedtuple |
19 | from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple | |
19 | from typing import ( | |
20 | TYPE_CHECKING, | |
21 | Any, | |
22 | Dict, | |
23 | Generator, | |
24 | Iterable, | |
25 | List, | |
26 | Optional, | |
27 | Set, | |
28 | Tuple, | |
29 | ) | |
20 | 30 | |
21 | 31 | import attr |
22 | 32 | from prometheus_client import Counter |
34 | 44 | from synapse.storage.util.id_generators import MultiWriterIdGenerator |
35 | 45 | from synapse.types import StateMap, get_domain_from_id |
36 | 46 | from synapse.util import json_encoder |
37 | from synapse.util.iterutils import batch_iter | |
47 | from synapse.util.iterutils import batch_iter, sorted_topologically | |
38 | 48 | |
39 | 49 | if TYPE_CHECKING: |
40 | 50 | from synapse.server import HomeServer |
365 | 375 | # Insert into event_to_state_groups. |
366 | 376 | self._store_event_state_mappings_txn(txn, events_and_contexts) |
367 | 377 | |
378 | self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) | |
379 | ||
380 | # _store_rejected_events_txn filters out any events which were | |
381 | # rejected, and returns the filtered list. | |
382 | events_and_contexts = self._store_rejected_events_txn( | |
383 | txn, events_and_contexts=events_and_contexts | |
384 | ) | |
385 | ||
386 | # From this point onwards the events are only ones that weren't | |
387 | # rejected. | |
388 | ||
389 | self._update_metadata_tables_txn( | |
390 | txn, | |
391 | events_and_contexts=events_and_contexts, | |
392 | all_events_and_contexts=all_events_and_contexts, | |
393 | backfilled=backfilled, | |
394 | ) | |
395 | ||
396 | # We call this last as it assumes we've inserted the events into | |
397 | # room_memberships, where applicable. | |
398 | self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) | |
399 | ||
400 | def _persist_event_auth_chain_txn( | |
401 | self, txn: LoggingTransaction, events: List[EventBase], | |
402 | ) -> None: | |
403 | ||
404 | # We only care about state events, so return early if there are no state events. |
405 | if not any(e.is_state() for e in events): | |
406 | return | |
407 | ||
368 | 408 | # We want to store event_auth mappings for rejected events, as they're |
369 | 409 | # used in state res v2. |
370 | 410 | # This is only necessary if the rejected event appears in an accepted |
380 | 420 | "room_id": event.room_id, |
381 | 421 | "auth_id": auth_id, |
382 | 422 | } |
383 | for event, _ in events_and_contexts | |
423 | for event in events | |
384 | 424 | for auth_id in event.auth_event_ids() |
385 | 425 | if event.is_state() |
386 | 426 | ], |
387 | 427 | ) |
388 | 428 | |
389 | # _store_rejected_events_txn filters out any events which were | |
390 | # rejected, and returns the filtered list. | |
391 | events_and_contexts = self._store_rejected_events_txn( | |
392 | txn, events_and_contexts=events_and_contexts | |
393 | ) | |
394 | ||
395 | # From this point onwards the events are only ones that weren't | |
396 | # rejected. | |
397 | ||
398 | self._update_metadata_tables_txn( | |
399 | txn, | |
400 | events_and_contexts=events_and_contexts, | |
401 | all_events_and_contexts=all_events_and_contexts, | |
402 | backfilled=backfilled, | |
403 | ) | |
404 | ||
405 | # We call this last as it assumes we've inserted the events into | |
406 | # room_memberships, where applicable. | |
407 | self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) | |
429 | # We now calculate chain ID/sequence numbers for any state events we're | |
430 | # persisting. We ignore out of band memberships as we're not in the room | |
431 | # and won't have their auth chain (we'll fix it up later if we join the | |
432 | # room). | |
433 | # | |
434 | # See: docs/auth_chain_difference_algorithm.md | |
435 | ||
436 | # We ignore legacy rooms that we aren't filling the chain cover index | |
437 | # for. | |
438 | rows = self.db_pool.simple_select_many_txn( | |
439 | txn, | |
440 | table="rooms", | |
441 | column="room_id", | |
442 | iterable={event.room_id for event in events if event.is_state()}, | |
443 | keyvalues={}, | |
444 | retcols=("room_id", "has_auth_chain_index"), | |
445 | ) | |
446 | rooms_using_chain_index = { | |
447 | row["room_id"] for row in rows if row["has_auth_chain_index"] | |
448 | } | |
449 | ||
450 | state_events = { | |
451 | event.event_id: event | |
452 | for event in events | |
453 | if event.is_state() and event.room_id in rooms_using_chain_index | |
454 | } | |
455 | ||
456 | if not state_events: | |
457 | return | |
458 | ||
459 | # We need to know the type/state_key and auth events of the events we're | |
460 | # calculating chain IDs for. We don't rely on having the full Event | |
461 | # instances as we'll potentially be pulling more events from the DB and | |
462 | # we don't need the overhead of fetching/parsing the full event JSON. | |
463 | event_to_types = { | |
464 | e.event_id: (e.type, e.state_key) for e in state_events.values() | |
465 | } | |
466 | event_to_auth_chain = { | |
467 | e.event_id: e.auth_event_ids() for e in state_events.values() | |
468 | } | |
469 | event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} | |
470 | ||
471 | self._add_chain_cover_index( | |
472 | txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, | |
473 | ) | |
474 | ||
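# (Editorial sketch, not part of the upstream diff: the three mappings built
# above take shapes like the following, with hypothetical event/room IDs.)
#
#   event_to_room_id    = {"$event1": "!room:example.org"}
#   event_to_types      = {"$event1": ("m.room.member", "@user:example.org")}
#   event_to_auth_chain = {"$event1": ["$create", "$power", "$join"]}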
475 | @staticmethod | |
476 | def _add_chain_cover_index( | |
477 | txn, | |
478 | db_pool: DatabasePool, | |
479 | event_to_room_id: Dict[str, str], | |
480 | event_to_types: Dict[str, Tuple[str, str]], | |
481 | event_to_auth_chain: Dict[str, List[str]], | |
482 | ) -> None: | |
483 | """Calculate the chain cover index for the given events. | |
484 | ||
485 | Args: | |
486 | event_to_room_id: Event ID to the room ID of the event | |
487 | event_to_types: Event ID to type and state_key of the event | |
488 | event_to_auth_chain: Event ID to list of auth event IDs of the | |
489 | event (events with no auth events can be excluded). | |
490 | """ | |
491 | ||
492 | # Map from event ID to chain ID/sequence number. | |
493 | chain_map = {} # type: Dict[str, Tuple[int, int]] | |
494 | ||
495 | # Set of event IDs to calculate chain ID/seq numbers for. | |
496 | events_to_calc_chain_id_for = set(event_to_room_id) | |
497 | ||
498 | # We check if there are any events that need to be handled in the rooms | |
499 | # we're looking at. These should just be out of band memberships, where | |
500 | # we didn't have the auth chain when we first persisted. | |
501 | rows = db_pool.simple_select_many_txn( | |
502 | txn, | |
503 | table="event_auth_chain_to_calculate", | |
504 | keyvalues={}, | |
505 | column="room_id", | |
506 | iterable=set(event_to_room_id.values()), | |
507 | retcols=("event_id", "type", "state_key"), | |
508 | ) | |
509 | for row in rows: | |
510 | event_id = row["event_id"] | |
511 | event_type = row["type"] | |
512 | state_key = row["state_key"] | |
513 | ||
514 | # (We could pull out the auth events for all rows at once using | |
515 | # simple_select_many, but this case happens rarely and almost always | |
516 | # with a single row.) | |
517 | auth_events = db_pool.simple_select_onecol_txn( | |
518 | txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", | |
519 | ) | |
520 | ||
521 | events_to_calc_chain_id_for.add(event_id) | |
522 | event_to_types[event_id] = (event_type, state_key) | |
523 | event_to_auth_chain[event_id] = auth_events | |
524 | ||
525 | # First we get the chain ID and sequence numbers for the events' | |
526 | # auth events (that aren't also currently being persisted). | |
527 | # | |
528 | # Note that there is an edge case here where we might not have | |
529 | # calculated chains and sequence numbers for events that were "out | |
530 | # of band". We handle this case by fetching the necessary info and | |
531 | # adding it to the set of events to calculate chain IDs for. | |
532 | ||
533 | missing_auth_chains = { | |
534 | a_id | |
535 | for auth_events in event_to_auth_chain.values() | |
536 | for a_id in auth_events | |
537 | if a_id not in events_to_calc_chain_id_for | |
538 | } | |
539 | ||
540 | # We loop here in case we find an out of band membership and need to | |
541 | # fetch its auth event info. | |
542 | while missing_auth_chains: | |
543 | sql = """ | |
544 | SELECT event_id, events.type, state_key, chain_id, sequence_number | |
545 | FROM events | |
546 | INNER JOIN state_events USING (event_id) | |
547 | LEFT JOIN event_auth_chains USING (event_id) | |
548 | WHERE | |
549 | """ | |
550 | clause, args = make_in_list_sql_clause( | |
551 | txn.database_engine, "event_id", missing_auth_chains, | |
552 | ) | |
553 | txn.execute(sql + clause, args) | |
554 | ||
555 | missing_auth_chains.clear() | |
556 | ||
557 | for auth_id, event_type, state_key, chain_id, sequence_number in txn: | |
558 | event_to_types[auth_id] = (event_type, state_key) | |
559 | ||
560 | if chain_id is None: | |
561 | # No chain ID, so the event was persisted out of band. | |
562 | # We add it to the list of events to calculate auth chains for. | |
563 | ||
564 | events_to_calc_chain_id_for.add(auth_id) | |
565 | ||
566 | event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn( | |
567 | txn, | |
568 | "event_auth", | |
569 | keyvalues={"event_id": auth_id}, | |
570 | retcol="auth_id", | |
571 | ) | |
572 | ||
573 | missing_auth_chains.update( | |
574 | e | |
575 | for e in event_to_auth_chain[auth_id] | |
576 | if e not in event_to_types | |
577 | ) | |
578 | else: | |
579 | chain_map[auth_id] = (chain_id, sequence_number) | |
580 | ||
581 | # Now we check if we have any events for which we don't have the auth | |
582 | # chain; these should only be out of band memberships. | |
583 | for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain): | |
584 | for auth_id in event_to_auth_chain[event_id]: | |
585 | if ( | |
586 | auth_id not in chain_map | |
587 | and auth_id not in events_to_calc_chain_id_for | |
588 | ): | |
589 | events_to_calc_chain_id_for.discard(event_id) | |
590 | ||
591 | # If this is an event we're trying to persist we add it to | |
592 | # the list of events to calculate chain IDs for next time | |
593 | # around. (Otherwise we will have already added it to the | |
594 | # table). | |
595 | room_id = event_to_room_id.get(event_id) | |
596 | if room_id: | |
597 | e_type, state_key = event_to_types[event_id] | |
598 | db_pool.simple_insert_txn( | |
599 | txn, | |
600 | table="event_auth_chain_to_calculate", | |
601 | values={ | |
602 | "event_id": event_id, | |
603 | "room_id": room_id, | |
604 | "type": e_type, | |
605 | "state_key": state_key, | |
606 | }, | |
607 | ) | |
608 | ||
609 | # We stop checking the event's auth events since we've | |
610 | # discarded it. | |
611 | break | |
612 | ||
613 | if not events_to_calc_chain_id_for: | |
614 | return | |
615 | ||
616 | # We now calculate the chain IDs/sequence numbers for the events. We | |
617 | # do this by looking at the chain ID and sequence number of any auth | |
618 | # event with the same type/state_key and incrementing the sequence | |
619 | # number by one. If there was no match or the chain ID/sequence | |
620 | # number is already taken we generate a new chain. | |
621 | # | |
622 | # We need to do this in a topologically sorted order as we want to | |
623 | # generate chain IDs/sequence numbers of an event's auth events | |
624 | # before the event itself. | |
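# (Illustrative example, not part of the upstream diff: if a new
# m.room.power_levels event counts the previous power_levels event, at
# chain/sequence (3, 2), among its auth events, we first propose (3, 3)
# for the new event; only if (3, 3) is already allocated do we fall back
# to a fresh chain, giving the event (new_chain_id, 1).)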
625 | chains_tuples_allocated = set() # type: Set[Tuple[int, int]] | |
626 | new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] | |
627 | for event_id in sorted_topologically( | |
628 | events_to_calc_chain_id_for, event_to_auth_chain | |
629 | ): | |
630 | existing_chain_id = None | |
631 | for auth_id in event_to_auth_chain.get(event_id, []): | |
632 | if event_to_types.get(event_id) == event_to_types.get(auth_id): | |
633 | existing_chain_id = chain_map[auth_id] | |
634 | break | |
635 | ||
636 | new_chain_tuple = None | |
637 | if existing_chain_id: | |
638 | # We found a chain ID/sequence number candidate, check it's | |
639 | # not already taken. | |
640 | proposed_new_id = existing_chain_id[0] | |
641 | proposed_new_seq = existing_chain_id[1] + 1 | |
642 | if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: | |
643 | already_allocated = db_pool.simple_select_one_onecol_txn( | |
644 | txn, | |
645 | table="event_auth_chains", | |
646 | keyvalues={ | |
647 | "chain_id": proposed_new_id, | |
648 | "sequence_number": proposed_new_seq, | |
649 | }, | |
650 | retcol="event_id", | |
651 | allow_none=True, | |
652 | ) | |
653 | if already_allocated: | |
654 | # Mark it as already allocated so we don't need to hit | |
655 | # the DB again. | |
656 | chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) | |
657 | else: | |
658 | new_chain_tuple = ( | |
659 | proposed_new_id, | |
660 | proposed_new_seq, | |
661 | ) | |
662 | ||
663 | if not new_chain_tuple: | |
664 | new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) | |
665 | ||
666 | chains_tuples_allocated.add(new_chain_tuple) | |
667 | ||
668 | chain_map[event_id] = new_chain_tuple | |
669 | new_chain_tuples[event_id] = new_chain_tuple | |
670 | ||
671 | db_pool.simple_insert_many_txn( | |
672 | txn, | |
673 | table="event_auth_chains", | |
674 | values=[ | |
675 | {"event_id": event_id, "chain_id": c_id, "sequence_number": seq} | |
676 | for event_id, (c_id, seq) in new_chain_tuples.items() | |
677 | ], | |
678 | ) | |
679 | ||
680 | db_pool.simple_delete_many_txn( | |
681 | txn, | |
682 | table="event_auth_chain_to_calculate", | |
683 | keyvalues={}, | |
684 | column="event_id", | |
685 | iterable=new_chain_tuples, | |
686 | ) | |
687 | ||
688 | # Now we need to calculate any new links between chains caused by | |
689 | # the new events. | |
690 | # | |
691 | # Links are pairs of chain ID/sequence numbers such that for any | |
692 | # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain | |
693 | # if and only if there is at least one link (CA, S1) -> (CB, S2) | |
694 | # where SA >= S1 and S2 >= SB. | |
695 | # | |
696 | # We try to avoid adding redundant links to the table, e.g. if we | |
697 | # have two links between two chains which both start/end at the same | |
698 | # sequence numbers (or cross) then one can be safely dropped. | |
699 | # | |
700 | # To calculate new links we look at every new event and: | |
701 | # 1. Fetch the chain ID/sequence numbers of its auth events, | |
702 | # discarding any that are reachable by other auth events, or | |
703 | # that have the same chain ID as the event. | |
704 | # 2. For each retained auth event we: | |
705 | # a. Add a link from the event's chain ID/sequence number to | |
706 | # the auth event's; and | |
707 | # b. Add a link from the event to every chain reachable by the | |
708 | # auth event. | |
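# (Illustrative, not part of the upstream diff: given a single link
# (C_A, 3) -> (C_B, 2), any event at (C_A, seq >= 3) has every event at
# (C_B, seq <= 2) in its auth chain, and no other C_A/C_B pairs are
# related.)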
709 | ||
710 | # Step 1, fetch all existing links from all the chains we've seen | |
711 | # referenced. | |
712 | chain_links = _LinkMap() | |
713 | rows = db_pool.simple_select_many_txn( | |
714 | txn, | |
715 | table="event_auth_chain_links", | |
716 | column="origin_chain_id", | |
717 | iterable={chain_id for chain_id, _ in chain_map.values()}, | |
718 | keyvalues={}, | |
719 | retcols=( | |
720 | "origin_chain_id", | |
721 | "origin_sequence_number", | |
722 | "target_chain_id", | |
723 | "target_sequence_number", | |
724 | ), | |
725 | ) | |
726 | for row in rows: | |
727 | chain_links.add_link( | |
728 | (row["origin_chain_id"], row["origin_sequence_number"]), | |
729 | (row["target_chain_id"], row["target_sequence_number"]), | |
730 | new=False, | |
731 | ) | |
732 | ||
733 | # We do this in topological order to avoid adding redundant links. | |
734 | for event_id in sorted_topologically( | |
735 | events_to_calc_chain_id_for, event_to_auth_chain | |
736 | ): | |
737 | chain_id, sequence_number = chain_map[event_id] | |
738 | ||
739 | # Filter out auth events that are reachable by other auth | |
740 | # events. We do this by looking at every permutation of pairs of | |
741 | # auth events (A, B) to check if B is reachable from A. | |
742 | reduction = { | |
743 | a_id | |
744 | for a_id in event_to_auth_chain.get(event_id, []) | |
745 | if chain_map[a_id][0] != chain_id | |
746 | } | |
747 | for start_auth_id, end_auth_id in itertools.permutations( | |
748 | event_to_auth_chain.get(event_id, []), r=2, | |
749 | ): | |
750 | if chain_links.exists_path_from( | |
751 | chain_map[start_auth_id], chain_map[end_auth_id] | |
752 | ): | |
753 | reduction.discard(end_auth_id) | |
754 | ||
755 | # Step 2, figure out what the new links are from the reduced | |
756 | # list of auth events. | |
757 | for auth_id in reduction: | |
758 | auth_chain_id, auth_sequence_number = chain_map[auth_id] | |
759 | ||
760 | # Step 2a, add link between the event and auth event | |
761 | chain_links.add_link( | |
762 | (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) | |
763 | ) | |
764 | ||
765 | # Step 2b, add a link to chains reachable from the auth | |
766 | # event. | |
767 | for target_id, target_seq in chain_links.get_links_from( | |
768 | (auth_chain_id, auth_sequence_number) | |
769 | ): | |
770 | if target_id == chain_id: | |
771 | continue | |
772 | ||
773 | chain_links.add_link( | |
774 | (chain_id, sequence_number), (target_id, target_seq) | |
775 | ) | |
776 | ||
777 | db_pool.simple_insert_many_txn( | |
778 | txn, | |
779 | table="event_auth_chain_links", | |
780 | values=[ | |
781 | { | |
782 | "origin_chain_id": source_id, | |
783 | "origin_sequence_number": source_seq, | |
784 | "target_chain_id": target_id, | |
785 | "target_sequence_number": target_seq, | |
786 | } | |
787 | for ( | |
788 | source_id, | |
789 | source_seq, | |
790 | target_id, | |
791 | target_seq, | |
792 | ) in chain_links.get_additions() | |
793 | ], | |
794 | ) | |
408 | 795 | |
409 | 796 | def _persist_transaction_ids_txn( |
410 | 797 | self, |
798 | 1185 | return [ec for ec in events_and_contexts if ec[0] not in to_remove] |
799 | 1186 | |
800 | 1187 | def _store_event_txn(self, txn, events_and_contexts): |
801 | """Insert new events into the event and event_json tables | |
1188 | """Insert new events into the event, event_json, redaction and | |
1189 | state_events tables. | |
802 | 1190 | |
803 | 1191 | Args: |
804 | 1192 | txn (twisted.enterprise.adbapi.Connection): db connection |
870 | 1258 | updatevalues={"have_censored": False}, |
871 | 1259 | ) |
872 | 1260 | |
1261 | state_events_and_contexts = [ | |
1262 | ec for ec in events_and_contexts if ec[0].is_state() | |
1263 | ] | |
1264 | ||
1265 | state_values = [] | |
1266 | for event, context in state_events_and_contexts: | |
1267 | vals = { | |
1268 | "event_id": event.event_id, | |
1269 | "room_id": event.room_id, | |
1270 | "type": event.type, | |
1271 | "state_key": event.state_key, | |
1272 | } | |
1273 | ||
1274 | # TODO: How does this work with backfilling? | |
1275 | if hasattr(event, "replaces_state"): | |
1276 | vals["prev_state"] = event.replaces_state | |
1277 | ||
1278 | state_values.append(vals) | |
1279 | ||
1280 | self.db_pool.simple_insert_many_txn( | |
1281 | txn, table="state_events", values=state_values | |
1282 | ) | |
1283 | ||
873 | 1284 | def _store_rejected_events_txn(self, txn, events_and_contexts): |
874 | 1285 | """Add rows to the 'rejections' table for received events which were |
875 | 1286 | rejected |
984 | 1395 | # Insert event_reference_hashes table. |
985 | 1396 | self._store_event_reference_hashes_txn( |
986 | 1397 | txn, [event for event, _ in events_and_contexts] |
987 | ) | |
988 | ||
989 | state_events_and_contexts = [ | |
990 | ec for ec in events_and_contexts if ec[0].is_state() | |
991 | ] | |
992 | ||
993 | state_values = [] | |
994 | for event, context in state_events_and_contexts: | |
995 | vals = { | |
996 | "event_id": event.event_id, | |
997 | "room_id": event.room_id, | |
998 | "type": event.type, | |
999 | "state_key": event.state_key, | |
1000 | } | |
1001 | ||
1002 | # TODO: How does this work with backfilling? | |
1003 | if hasattr(event, "replaces_state"): | |
1004 | vals["prev_state"] = event.replaces_state | |
1005 | ||
1006 | state_values.append(vals) | |
1007 | ||
1008 | self.db_pool.simple_insert_many_txn( | |
1009 | txn, table="state_events", values=state_values | |
1010 | 1398 | ) |
1011 | 1399 | |
1012 | 1400 | # Prefill the event cache |
1519 | 1907 | if not ev.internal_metadata.is_outlier() |
1520 | 1908 | ], |
1521 | 1909 | ) |
1910 | ||
1911 | ||
1912 | @attr.s(slots=True) | |
1913 | class _LinkMap: | |
1914 | """A helper type for tracking links between chains. | |
1915 | """ | |
1916 | ||
1917 | # Stores the set of links as nested maps: source chain ID -> target chain ID | |
1918 | # -> source sequence number -> target sequence number. | |
1919 | maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict) | |
1920 | ||
1921 | # Stores the links that have been added (with new set to true), as tuples of | |
1922 | # `(source chain ID, source sequence no, target chain ID, target sequence no.)` | |
1923 | additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set) | |
1924 | ||
1925 | def add_link( | |
1926 | self, | |
1927 | src_tuple: Tuple[int, int], | |
1928 | target_tuple: Tuple[int, int], | |
1929 | new: bool = True, | |
1930 | ) -> bool: | |
1931 | """Add a new link between two chains, ensuring no redundant links are added. | |
1932 | ||
1933 | New links should be added in topological order. | |
1934 | ||
1935 | Args: | |
1936 | src_tuple: The chain ID/sequence number of the source of the link. | |
1937 | target_tuple: The chain ID/sequence number of the target of the link. | |
1938 | new: Whether this is a "new" link, i.e. whether it should be | |
1939 | returned by `get_additions`. | |
1940 | ||
1941 | Returns: | |
1942 | True if a link was added, false if the given link was dropped as redundant | |
1943 | """ | |
1944 | src_chain, src_seq = src_tuple | |
1945 | target_chain, target_seq = target_tuple | |
1946 | ||
1947 | current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {}) | |
1948 | ||
1949 | assert src_chain != target_chain | |
1950 | ||
1951 | if new: | |
1952 | # Check if the new link is redundant | |
1953 | for current_seq_src, current_seq_target in current_links.items(): | |
1954 | # If a link "crosses" another link then it's redundant. For example, | |
1955 | # in the following diagram link 1 (L1) is redundant, as any event reachable | |
1956 | # via L1 is *also* reachable via L2. | |
1957 | # | |
1958 | # Chain A Chain B | |
1959 | # | | | |
1960 | # L1 |------ | | |
1961 | # | | | | |
1962 | # L2 |---- | -->| | |
1963 | # | | | | |
1964 | # | |--->| | |
1965 | # | | | |
1966 | # | | | |
1967 | # | |
1968 | # So we only need to keep links which *do not* cross, i.e. links | |
1969 | # that both start and end above or below an existing link. | |
1970 | # | |
1971 | # Note: since we add links in topological order, we should never | |
1972 | # see `src_seq` less than `current_seq_src`. | |
1973 | ||
1974 | if current_seq_src <= src_seq and target_seq <= current_seq_target: | |
1975 | # This new link is redundant, nothing to do. | |
1976 | return False | |
1977 | ||
1978 | self.additions.add((src_chain, src_seq, target_chain, target_seq)) | |
1979 | ||
1980 | current_links[src_seq] = target_seq | |
1981 | return True | |
1982 | ||
1983 | def get_links_from( | |
1984 | self, src_tuple: Tuple[int, int] | |
1985 | ) -> Generator[Tuple[int, int], None, None]: | |
1986 | """Gets the chains reachable from the given chain/sequence number. | |
1987 | ||
1988 | Yields: | |
1989 | The chain ID and sequence number the link points to. | |
1990 | """ | |
1991 | src_chain, src_seq = src_tuple | |
1992 | for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): | |
1993 | for link_src_seq, target_seq in sequence_numbers.items(): | |
1994 | if link_src_seq <= src_seq: | |
1995 | yield target_id, target_seq | |
1996 | ||
1997 | def get_links_between( | |
1998 | self, source_chain: int, target_chain: int | |
1999 | ) -> Generator[Tuple[int, int], None, None]: | |
2000 | """Gets the links between two chains. | |
2001 | ||
2002 | Yields: | |
2003 | The source and target sequence numbers. | |
2004 | """ | |
2005 | ||
2006 | yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() | |
2007 | ||
2008 | def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: | |
2009 | """Gets any newly added links. | |
2010 | ||
2011 | Yields: | |
2012 | The source chain ID/sequence number and target chain ID/sequence number | |
2013 | """ | |
2014 | ||
2015 | for src_chain, src_seq, target_chain, _ in self.additions: | |
2016 | target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq) | |
2017 | if target_seq is not None: | |
2018 | yield (src_chain, src_seq, target_chain, target_seq) | |
2019 | ||
2020 | def exists_path_from( | |
2021 | self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], | |
2022 | ) -> bool: | |
2023 | """Checks if there is a path between the source chain ID/sequence and | |
2024 | target chain ID/sequence. | |
2025 | """ | |
2026 | src_chain, src_seq = src_tuple | |
2027 | target_chain, target_seq = target_tuple | |
2028 | ||
2029 | if src_chain == target_chain: | |
2030 | return target_seq <= src_seq | |
2031 | ||
2032 | links = self.get_links_between(src_chain, target_chain) | |
2033 | for link_start_seq, link_end_seq in links: | |
2034 | if link_start_seq <= src_seq and target_seq <= link_end_seq: | |
2035 | return True | |
2036 | ||
2037 | return False |
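A minimal usage sketch (editorial, not part of the upstream diff) may help pin down the link semantics above; it assumes only the `_LinkMap` class as defined:

link_map = _LinkMap()

# Record that chain 1 at sequence 3 links to chain 2 at sequence 2.
link_map.add_link((1, 3), (2, 2))

# (1, 4) can reach (2, 2): the link starts at or below sequence 4 on the
# source side and ends at or above sequence 2 on the target side.
assert link_map.exists_path_from((1, 4), (2, 2))

# (1, 2) cannot reach chain 2 at all: the only link starts at sequence 3.
assert not link_map.exists_path_from((1, 2), (2, 1))

# A "crossing" link is dropped as redundant: anything reachable via a
# (1, 4) -> (2, 1) link is already reachable via (1, 3) -> (2, 2).
assert not link_map.add_link((1, 4), (2, 1))

# Only the genuinely new link is reported for insertion into the table.
assert list(link_map.get_additions()) == [(1, 3, 2, 2)]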
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import logging |
16 | from typing import Dict, List, Optional, Tuple | |
17 | ||
18 | import attr | |
16 | 19 | |
17 | 20 | from synapse.api.constants import EventContentFields |
21 | from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | |
22 | from synapse.events import make_event_from_dict | |
18 | 23 | from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause |
19 | from synapse.storage.database import DatabasePool | |
24 | from synapse.storage.database import DatabasePool, make_tuple_comparison_clause | |
25 | from synapse.storage.databases.main.events import PersistEventsStore | |
26 | from synapse.storage.types import Cursor | |
27 | from synapse.types import JsonDict | |
20 | 28 | |
21 | 29 | logger = logging.getLogger(__name__) |
30 | ||
31 | ||
32 | @attr.s(slots=True, frozen=True) | |
33 | class _CalculateChainCover: | |
34 | """Return value for _calculate_chain_cover_txn. | |
35 | """ | |
36 | ||
37 | # The last room_id/depth/stream processed. | |
38 | room_id = attr.ib(type=str) | |
39 | depth = attr.ib(type=int) | |
40 | stream = attr.ib(type=int) | |
41 | ||
42 | # Number of rows processed | |
43 | processed_count = attr.ib(type=int) | |
44 | ||
45 | # Map from room_id to last depth/stream processed for each room that we have | |
46 | # processed all events for (i.e. the rooms we can flip the | |
47 | # `has_auth_chain_index` for) | |
48 | finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) | |
22 | 49 | |
23 | 50 | |
24 | 51 | class EventsBackgroundUpdatesStore(SQLBaseStore): |
96 | 123 | index_name="users_have_local_media", |
97 | 124 | table="local_media_repository", |
98 | 125 | columns=["user_id", "created_ts"], |
126 | ) | |
127 | ||
128 | self.db_pool.updates.register_background_update_handler( | |
129 | "rejected_events_metadata", self._rejected_events_metadata, | |
130 | ) | |
131 | ||
132 | self.db_pool.updates.register_background_update_handler( | |
133 | "chain_cover", self._chain_cover_index, | |
99 | 134 | ) |
100 | 135 | |
101 | 136 | async def _background_reindex_fields_sender(self, progress, batch_size): |
581 | 616 | await self.db_pool.updates._end_background_update("event_store_labels") |
582 | 617 | |
583 | 618 | return num_rows |
619 | ||
620 | async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int: | |
621 | """Adds rejected events to the `state_events` and `event_auth` metadata | |
622 | tables. | |
623 | """ | |
624 | ||
625 | last_event_id = progress.get("last_event_id", "") | |
626 | ||
627 | def get_rejected_events( | |
628 | txn: Cursor, | |
629 | ) -> List[Tuple[str, str, JsonDict, bool, bool]]: | |
630 | # Fetch the rejected events' JSON, their room version and whether we | |
631 | # have inserted them into the state_events or event_auth tables. | |
632 | # | |
633 | # Note we can assume that events that don't have a corresponding | |
634 | # room version are V1 rooms. | |
635 | sql = """ | |
636 | SELECT DISTINCT | |
637 | event_id, | |
638 | COALESCE(room_version, '1'), | |
639 | json, | |
640 | state_events.event_id IS NOT NULL, | |
641 | event_auth.event_id IS NOT NULL | |
642 | FROM rejections | |
643 | INNER JOIN event_json USING (event_id) | |
644 | LEFT JOIN rooms USING (room_id) | |
645 | LEFT JOIN state_events USING (event_id) | |
646 | LEFT JOIN event_auth USING (event_id) | |
647 | WHERE event_id > ? | |
648 | ORDER BY event_id | |
649 | LIMIT ? | |
650 | """ | |
651 | ||
652 | txn.execute(sql, (last_event_id, batch_size,)) | |
653 | ||
654 | return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore | |
655 | ||
656 | results = await self.db_pool.runInteraction( | |
657 | desc="_rejected_events_metadata_get", func=get_rejected_events | |
658 | ) | |
659 | ||
660 | if not results: | |
661 | await self.db_pool.updates._end_background_update( | |
662 | "rejected_events_metadata" | |
663 | ) | |
664 | return 0 | |
665 | ||
666 | state_events = [] | |
667 | auth_events = [] | |
668 | for event_id, room_version, event_json, has_state, has_event_auth in results: | |
669 | last_event_id = event_id | |
670 | ||
671 | if has_state and has_event_auth: | |
672 | continue | |
673 | ||
674 | room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version) | |
675 | if not room_version_obj: | |
676 | # We no longer support this room version, so we just ignore the | |
677 | # event entirely. | |
678 | logger.info( | |
679 | "Ignoring event with unknown room version %r: %r", | |
680 | room_version, | |
681 | event_id, | |
682 | ) | |
683 | continue | |
684 | ||
685 | event = make_event_from_dict(event_json, room_version_obj) | |
686 | ||
687 | if not event.is_state(): | |
688 | continue | |
689 | ||
690 | if not has_state: | |
691 | state_events.append( | |
692 | { | |
693 | "event_id": event.event_id, | |
694 | "room_id": event.room_id, | |
695 | "type": event.type, | |
696 | "state_key": event.state_key, | |
697 | } | |
698 | ) | |
699 | ||
700 | if not has_event_auth: | |
701 | for auth_id in event.auth_event_ids(): | |
702 | auth_events.append( | |
703 | { | |
704 | "room_id": event.room_id, | |
705 | "event_id": event.event_id, | |
706 | "auth_id": auth_id, | |
707 | } | |
708 | ) | |
709 | ||
710 | if state_events: | |
711 | await self.db_pool.simple_insert_many( | |
712 | table="state_events", | |
713 | values=state_events, | |
714 | desc="_rejected_events_metadata_state_events", | |
715 | ) | |
716 | ||
717 | if auth_events: | |
718 | await self.db_pool.simple_insert_many( | |
719 | table="event_auth", | |
720 | values=auth_events, | |
721 | desc="_rejected_events_metadata_event_auth", | |
722 | ) | |
723 | ||
724 | await self.db_pool.updates._background_update_progress( | |
725 | "rejected_events_metadata", {"last_event_id": last_event_id} | |
726 | ) | |
727 | ||
728 | if len(results) < batch_size: | |
729 | await self.db_pool.updates._end_background_update( | |
730 | "rejected_events_metadata" | |
731 | ) | |
732 | ||
733 | return len(results) | |
734 | ||
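# (Editorial note, not part of the upstream diff: the progress dict
# round-trips between batches. On the first call progress is empty, so
# last_event_id defaults to ""; after each batch we persist
# {"last_event_id": <last ID seen>} and resume from there, ending the
# background update once a batch returns fewer than batch_size rows.)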
735 | async def _chain_cover_index(self, progress: dict, batch_size: int) -> int: | |
736 | """A background updates that iterates over all rooms and generates the | |
737 | chain cover index for them. | |
738 | """ | |
739 | ||
740 | current_room_id = progress.get("current_room_id", "") | |
741 | ||
742 | # Where we've processed up to in the room, defaults to the start of the | |
743 | # room. | |
744 | last_depth = progress.get("last_depth", -1) | |
745 | last_stream = progress.get("last_stream", -1) | |
746 | ||
747 | result = await self.db_pool.runInteraction( | |
748 | "_chain_cover_index", | |
749 | self._calculate_chain_cover_txn, | |
750 | current_room_id, | |
751 | last_depth, | |
752 | last_stream, | |
753 | batch_size, | |
754 | single_room=False, | |
755 | ) | |
756 | ||
757 | finished = result.processed_count == 0 | |
758 | ||
759 | total_rows_processed = result.processed_count | |
760 | current_room_id = result.room_id | |
761 | last_depth = result.depth | |
762 | last_stream = result.stream | |
763 | ||
764 | for room_id, (depth, stream) in result.finished_room_map.items(): | |
765 | # If we've done all the events in the room we flip the | |
766 | # `has_auth_chain_index` in the DB. Note that it's possible for | |
767 | # further events to be persisted between the above and setting the | |
768 | # flag without having the chain cover calculated for them. This is | |
769 | # fine as a) the code gracefully handles these cases and b) we'll | |
770 | # calculate them below. | |
771 | ||
772 | await self.db_pool.simple_update( | |
773 | table="rooms", | |
774 | keyvalues={"room_id": room_id}, | |
775 | updatevalues={"has_auth_chain_index": True}, | |
776 | desc="_chain_cover_index", | |
777 | ) | |
778 | ||
779 | # Handle any events that might have raced with us flipping the | |
780 | # bit above. | |
781 | result = await self.db_pool.runInteraction( | |
782 | "_chain_cover_index", | |
783 | self._calculate_chain_cover_txn, | |
784 | room_id, | |
785 | depth, | |
786 | stream, | |
787 | batch_size=None, | |
788 | single_room=True, | |
789 | ) | |
790 | ||
791 | total_rows_processed += result.processed_count | |
792 | ||
793 | if finished: | |
794 | await self.db_pool.updates._end_background_update("chain_cover") | |
795 | return total_rows_processed | |
796 | ||
797 | await self.db_pool.updates._background_update_progress( | |
798 | "chain_cover", | |
799 | { | |
800 | "current_room_id": current_room_id, | |
801 | "last_depth": last_depth, | |
802 | "last_stream": last_stream, | |
803 | }, | |
804 | ) | |
805 | ||
806 | return total_rows_processed | |
807 | ||
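# (Editorial note, not part of the upstream diff: the update above works in
# two phases per batch. First _calculate_chain_cover_txn walks events across
# all rooms in (room_id, depth, stream) order; then, for each room it has
# finished, we flip rooms.has_auth_chain_index and re-run the calculation
# for just that room to index any events that raced with the flip.)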
808 | def _calculate_chain_cover_txn( | |
809 | self, | |
810 | txn: Cursor, | |
811 | last_room_id: str, | |
812 | last_depth: int, | |
813 | last_stream: int, | |
814 | batch_size: Optional[int], | |
815 | single_room: bool, | |
816 | ) -> _CalculateChainCover: | |
817 | """Calculate the chain cover for `batch_size` events, ordered by | |
818 | `(room_id, depth, stream)`. | |
819 | ||
820 | Args: | |
821 | txn: The database transaction. | |
822 | last_room_id, last_depth, last_stream: The `(room_id, depth, stream)` | |
823 | tuple to fetch results after. | |
824 | batch_size: The maximum number of events to process. If None then | |
825 | no limit. | |
826 | single_room: Whether to calculate the index for just the given | |
827 | room. | |
828 | """ | |
829 | ||
830 | # Get the next set of events in the room (that we haven't already | |
831 | # computed chain cover for). We do this in topological order. | |
832 | ||
833 | # We want to do a `(topological_ordering, stream_ordering) > (?,?)` | |
834 | # comparison, but that is not supported on older SQLite versions | |
835 | tuple_clause, tuple_args = make_tuple_comparison_clause( | |
836 | self.database_engine, | |
837 | [ | |
838 | ("events.room_id", last_room_id), | |
839 | ("topological_ordering", last_depth), | |
840 | ("stream_ordering", last_stream), | |
841 | ], | |
842 | ) | |
843 | ||
844 | extra_clause = "" | |
845 | if single_room: | |
846 | extra_clause = "AND events.room_id = ?" | |
847 | tuple_args.append(last_room_id) | |
848 | ||
849 | sql = """ | |
850 | SELECT | |
851 | event_id, state_events.type, state_events.state_key, | |
852 | topological_ordering, stream_ordering, | |
853 | events.room_id | |
854 | FROM events | |
855 | INNER JOIN state_events USING (event_id) | |
856 | LEFT JOIN event_auth_chains USING (event_id) | |
857 | LEFT JOIN event_auth_chain_to_calculate USING (event_id) | |
858 | WHERE event_auth_chains.event_id IS NULL | |
859 | AND event_auth_chain_to_calculate.event_id IS NULL | |
860 | AND %(tuple_cmp)s | |
861 | %(extra)s | |
862 | ORDER BY events.room_id, topological_ordering, stream_ordering | |
863 | %(limit)s | |
864 | """ % { | |
865 | "tuple_cmp": tuple_clause, | |
866 | "limit": "LIMIT ?" if batch_size is not None else "", | |
867 | "extra": extra_clause, | |
868 | } | |
869 | ||
870 | if batch_size is not None: | |
871 | tuple_args.append(batch_size) | |
872 | ||
873 | txn.execute(sql, tuple_args) | |
874 | rows = txn.fetchall() | |
875 | ||
876 | # Put the results in the necessary format for | |
877 | # `_add_chain_cover_index` | |
878 | event_to_room_id = {row[0]: row[5] for row in rows} | |
879 | event_to_types = {row[0]: (row[1], row[2]) for row in rows} | |
880 | ||
881 | # Calculate the new last position we've processed up to. | |
882 | new_last_depth = rows[-1][3] if rows else last_depth # type: int | |
883 | new_last_stream = rows[-1][4] if rows else last_stream # type: int | |
884 | new_last_room_id = rows[-1][5] if rows else "" # type: str | |
885 | ||
886 | # Map from room_id to last depth/stream_ordering processed for the room, | |
887 | # excluding the last room (which we're likely still processing). We also | |
888 | # need to include the room passed in if it's not included in the result | |
889 | # set (as we then know we've processed all events in said room). | |
890 | # | |
891 | # This is the set of rooms that we can now safely flip the | |
892 | # `has_auth_chain_index` bit for. | |
893 | finished_rooms = { | |
894 | row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id | |
895 | } | |
896 | if last_room_id not in finished_rooms and last_room_id != new_last_room_id: | |
897 | finished_rooms[last_room_id] = (last_depth, last_stream) | |
898 | ||
899 | count = len(rows) | |
900 | ||
901 | # We also need to fetch the auth events for them. | |
902 | auth_events = self.db_pool.simple_select_many_txn( | |
903 | txn, | |
904 | table="event_auth", | |
905 | column="event_id", | |
906 | iterable=event_to_room_id, | |
907 | keyvalues={}, | |
908 | retcols=("event_id", "auth_id"), | |
909 | ) | |
910 | ||
911 | event_to_auth_chain = {} # type: Dict[str, List[str]] | |
912 | for row in auth_events: | |
913 | event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) | |
914 | ||
915 | # Calculate and persist the chain cover index for this set of events. | |
916 | # | |
917 | # Annoyingly we need to gut wrench into the persist event store so that | |
918 | # we can reuse the function to calculate the chain cover for rooms. | |
919 | PersistEventsStore._add_chain_cover_index( | |
920 | txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, | |
921 | ) | |
922 | ||
923 | return _CalculateChainCover( | |
924 | room_id=new_last_room_id, | |
925 | depth=new_last_depth, | |
926 | stream=new_last_stream, | |
927 | processed_count=count, | |
928 | finished_room_map=finished_rooms, | |
929 | ) |
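An editorial sketch (not part of the upstream diff) of the tuple comparison used above; the exact SQL emitted by `make_tuple_comparison_clause` is engine-dependent, and the clauses shown in the comments are only indicative:

clause, args = make_tuple_comparison_clause(
    database_engine,  # assumed in scope, as self.database_engine is above
    [
        ("events.room_id", last_room_id),
        ("topological_ordering", last_depth),
        ("stream_ordering", last_stream),
    ],
)
# On PostgreSQL this can be a native row comparison, roughly:
#   (events.room_id, topological_ordering, stream_ordering) > (?, ?, ?)
# Older SQLite lacks tuple comparisons, so it has to be emulated with nested
# OR/AND clauses, along the lines of:
#   events.room_id > ? OR (events.room_id = ? AND (topological_ordering > ?
#     OR (topological_ordering = ? AND stream_ordering > ?)))
# with `args` repeating the bound values as many times as needed.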
95 | 95 | db=database, |
96 | 96 | stream_name="events", |
97 | 97 | instance_name=hs.get_instance_name(), |
98 | table="events", | |
99 | instance_column="instance_name", | |
100 | id_column="stream_ordering", | |
98 | tables=[("events", "instance_name", "stream_ordering")], | |
101 | 99 | sequence_name="events_stream_seq", |
102 | 100 | writers=hs.config.worker.writers.events, |
103 | 101 | ) |
106 | 104 | db=database, |
107 | 105 | stream_name="backfill", |
108 | 106 | instance_name=hs.get_instance_name(), |
109 | table="events", | |
110 | instance_column="instance_name", | |
111 | id_column="stream_ordering", | |
107 | tables=[("events", "instance_name", "stream_ordering")], | |
112 | 108 | sequence_name="events_backfill_stream_seq", |
113 | 109 | positive=False, |
114 | 110 | writers=hs.config.worker.writers.events, |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
168 | 169 | |
169 | 170 | async def get_local_media_before( |
170 | 171 | self, before_ts: int, size_gt: int, keep_profiles: bool, |
171 | ) -> Optional[List[str]]: | |
172 | ) -> List[str]: | |
172 | 173 | |
173 | 174 | # to find files that have never been accessed (last_access_ts IS NULL) |
174 | 175 | # compare with `created_ts` |
81 | 81 | ) |
82 | 82 | |
83 | 83 | async def set_profile_avatar_url( |
84 | self, user_localpart: str, new_avatar_url: str | |
84 | self, user_localpart: str, new_avatar_url: Optional[str] | |
85 | 85 | ) -> None: |
86 | 86 | await self.db_pool.simple_update_one( |
87 | 87 | table="profiles", |
16 | 16 | import logging |
17 | 17 | from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple |
18 | 18 | |
19 | from canonicaljson import encode_canonical_json | |
20 | ||
21 | 19 | from synapse.push import PusherConfig, ThrottleParams |
22 | 20 | from synapse.storage._base import SQLBaseStore, db_to_json |
23 | 21 | from synapse.storage.database import DatabasePool |
24 | 22 | from synapse.storage.types import Connection |
25 | 23 | from synapse.storage.util.id_generators import StreamIdGenerator |
26 | 24 | from synapse.types import JsonDict |
25 | from synapse.util import json_encoder | |
27 | 26 | from synapse.util.caches.descriptors import cached, cachedList |
28 | 27 | |
29 | 28 | if TYPE_CHECKING: |
314 | 313 | "device_display_name": device_display_name, |
315 | 314 | "ts": pushkey_ts, |
316 | 315 | "lang": lang, |
317 | "data": bytearray(encode_canonical_json(data)), | |
316 | "data": json_encoder.encode(data), | |
318 | 317 | "last_stream_ordering": last_stream_ordering, |
319 | 318 | "profile_tag": profile_tag, |
320 | 319 | "id": stream_id, |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | import abc | |
17 | 16 | import logging |
18 | 17 | from typing import Any, Dict, List, Optional, Tuple |
19 | 18 | |
20 | 19 | from twisted.internet import defer |
21 | 20 | |
21 | from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |
22 | from synapse.replication.tcp.streams import ReceiptsStream | |
22 | 23 | from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause |
23 | 24 | from synapse.storage.database import DatabasePool |
24 | from synapse.storage.util.id_generators import StreamIdGenerator | |
25 | from synapse.storage.engines import PostgresEngine | |
26 | from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator | |
25 | 27 | from synapse.types import JsonDict |
26 | 28 | from synapse.util import json_encoder |
27 | 29 | from synapse.util.caches.descriptors import cached, cachedList |
30 | 32 | logger = logging.getLogger(__name__) |
31 | 33 | |
32 | 34 | |
33 | # The ABCMeta metaclass ensures that it cannot be instantiated without | |
34 | # the abstract methods being implemented. | |
35 | class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): | |
36 | """This is an abstract base class where subclasses must implement | |
37 | `get_max_receipt_stream_id` which can be called in the initializer. | |
38 | """ | |
39 | ||
35 | class ReceiptsWorkerStore(SQLBaseStore): | |
40 | 36 | def __init__(self, database: DatabasePool, db_conn, hs): |
37 | self._instance_name = hs.get_instance_name() | |
38 | ||
39 | if isinstance(database.engine, PostgresEngine): | |
40 | self._can_write_to_receipts = ( | |
41 | self._instance_name in hs.config.worker.writers.receipts | |
42 | ) | |
43 | ||
44 | self._receipts_id_gen = MultiWriterIdGenerator( | |
45 | db_conn=db_conn, | |
46 | db=database, | |
47 | stream_name="receipts", | |
48 | instance_name=self._instance_name, | |
49 | tables=[("receipts_linearized", "instance_name", "stream_id")], | |
50 | sequence_name="receipts_sequence", | |
51 | writers=hs.config.worker.writers.receipts, | |
52 | ) | |
53 | else: | |
54 | self._can_write_to_receipts = True | |
55 | ||
56 | # We shouldn't be running in worker mode with SQLite, but it's useful | |
57 | # to support it for unit tests. | |
58 | # | |
59 | # If this process is the writer then we need to use | |
60 | # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets | |
61 | # updated over replication. (Multiple writers are not supported for | |
62 | # SQLite). | |
63 | if hs.get_instance_name() in hs.config.worker.writers.receipts: | |
64 | self._receipts_id_gen = StreamIdGenerator( | |
65 | db_conn, "receipts_linearized", "stream_id" | |
66 | ) | |
67 | else: | |
68 | self._receipts_id_gen = SlavedIdTracker( | |
69 | db_conn, "receipts_linearized", "stream_id" | |
70 | ) | |
71 | ||
41 | 72 | super().__init__(database, db_conn, hs) |
42 | 73 | |
43 | 74 | self._receipts_stream_cache = StreamChangeCache( |
44 | 75 | "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() |
45 | 76 | ) |
46 | 77 | |
47 | @abc.abstractmethod | |
48 | 78 | def get_max_receipt_stream_id(self): |
49 | 79 | """Get the current max stream ID for receipts stream |
50 | 80 | |
51 | 81 | Returns: |
52 | 82 | int |
53 | 83 | """ |
54 | raise NotImplementedError() | |
84 | return self._receipts_id_gen.get_current_token() | |
55 | 85 | |
56 | 86 | @cached() |
57 | 87 | async def get_users_with_read_receipts_in_room(self, room_id): |
427 | 457 | |
428 | 458 | self.get_users_with_read_receipts_in_room.invalidate((room_id,)) |
429 | 459 | |
430 | ||
431 | class ReceiptsStore(ReceiptsWorkerStore): | |
432 | def __init__(self, database: DatabasePool, db_conn, hs): | |
433 | # We instantiate this first as the ReceiptsWorkerStore constructor | |
434 | # needs to be able to call get_max_receipt_stream_id | |
435 | self._receipts_id_gen = StreamIdGenerator( | |
436 | db_conn, "receipts_linearized", "stream_id" | |
437 | ) | |
438 | ||
439 | super().__init__(database, db_conn, hs) | |
440 | ||
441 | def get_max_receipt_stream_id(self): | |
442 | return self._receipts_id_gen.get_current_token() | |
460 | def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): | |
461 | self.get_receipts_for_user.invalidate((user_id, receipt_type)) | |
462 | self._get_linearized_receipts_for_room.invalidate_many((room_id,)) | |
463 | self.get_last_receipt_event_id_for_user.invalidate( | |
464 | (user_id, room_id, receipt_type) | |
465 | ) | |
466 | self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) | |
467 | self.get_receipts_for_room.invalidate((room_id, receipt_type)) | |
468 | ||
469 | def process_replication_rows(self, stream_name, instance_name, token, rows): | |
470 | if stream_name == ReceiptsStream.NAME: | |
471 | self._receipts_id_gen.advance(instance_name, token) | |
472 | for row in rows: | |
473 | self.invalidate_caches_for_receipt( | |
474 | row.room_id, row.receipt_type, row.user_id | |
475 | ) | |
476 | self._receipts_stream_cache.entity_has_changed(row.room_id, token) | |
477 | ||
478 | return super().process_replication_rows(stream_name, instance_name, token, rows) | |
443 | 479 | |
444 | 480 | def insert_linearized_receipt_txn( |
445 | 481 | self, txn, room_id, receipt_type, user_id, event_id, data, stream_id |
451 | 487 | otherwise, the rx timestamp of the event that the RR corresponds to |
452 | 488 | (or 0 if the event is unknown) |
453 | 489 | """ |
490 | assert self._can_write_to_receipts | |
491 | ||
454 | 492 | res = self.db_pool.simple_select_one_txn( |
455 | 493 | txn, |
456 | 494 | table="events", |
482 | 520 | ) |
483 | 521 | return None |
484 | 522 | |
485 | txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) | |
486 | 523 | txn.call_after( |
487 | self._invalidate_get_users_with_receipts_in_room, | |
488 | room_id, | |
489 | receipt_type, | |
490 | user_id, | |
491 | ) | |
492 | txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) | |
493 | # FIXME: This shouldn't invalidate the whole cache | |
494 | txn.call_after( | |
495 | self._get_linearized_receipts_for_room.invalidate_many, (room_id,) | |
524 | self.invalidate_caches_for_receipt, room_id, receipt_type, user_id | |
496 | 525 | ) |
497 | 526 | |
498 | 527 | txn.call_after( |
499 | 528 | self._receipts_stream_cache.entity_has_changed, room_id, stream_id |
500 | ) | |
501 | ||
502 | txn.call_after( | |
503 | self.get_last_receipt_event_id_for_user.invalidate, | |
504 | (user_id, room_id, receipt_type), | |
505 | 529 | ) |
506 | 530 | |
507 | 531 | self.db_pool.simple_upsert_txn( |
542 | 566 | Automatically does conversion between linearized and graph |
543 | 567 | representations. |
544 | 568 | """ |
569 | assert self._can_write_to_receipts | |
570 | ||
545 | 571 | if not event_ids: |
546 | 572 | return None |
547 | 573 | |
606 | 632 | async def insert_graph_receipt( |
607 | 633 | self, room_id, receipt_type, user_id, event_ids, data |
608 | 634 | ): |
635 | assert self._can_write_to_receipts | |
636 | ||
609 | 637 | return await self.db_pool.runInteraction( |
610 | 638 | "insert_graph_receipt", |
611 | 639 | self.insert_graph_receipt_txn, |
619 | 647 | def insert_graph_receipt_txn( |
620 | 648 | self, txn, room_id, receipt_type, user_id, event_ids, data |
621 | 649 | ): |
650 | assert self._can_write_to_receipts | |
651 | ||
622 | 652 | txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) |
623 | 653 | txn.call_after( |
624 | 654 | self._invalidate_get_users_with_receipts_in_room, |
652 | 682 | "data": json_encoder.encode(data), |
653 | 683 | }, |
654 | 684 | ) |
685 | ||
686 | ||
687 | class ReceiptsStore(ReceiptsWorkerStore): | |
688 | pass |
15 | 15 | |
16 | 16 | import collections |
17 | 17 | import logging |
18 | import re | |
19 | 18 | from abc import abstractmethod |
20 | 19 | from enum import Enum |
21 | 20 | from typing import Any, Dict, List, Optional, Tuple |
29 | 28 | from synapse.types import JsonDict, ThirdPartyInstanceID |
30 | 29 | from synapse.util import json_encoder |
31 | 30 | from synapse.util.caches.descriptors import cached |
31 | from synapse.util.stringutils import MXC_REGEX | |
32 | 32 | |
33 | 33 | logger = logging.getLogger(__name__) |
34 | 34 | |
83 | 83 | return await self.db_pool.simple_select_one( |
84 | 84 | table="rooms", |
85 | 85 | keyvalues={"room_id": room_id}, |
86 | retcols=("room_id", "is_public", "creator"), | |
86 | retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), | |
87 | 87 | desc="get_room", |
88 | 88 | allow_none=True, |
89 | 89 | ) |
659 | 659 | The local and remote media as a lists of tuples where the key is |
660 | 660 | the hostname and the value is the media ID. |
661 | 661 | """ |
662 | mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") | |
663 | ||
664 | 662 | sql = """ |
665 | 663 | SELECT stream_ordering, json FROM events |
666 | 664 | JOIN event_json USING (room_id, event_id) |
687 | 685 | for url in (content_url, thumbnail_url): |
688 | 686 | if not url: |
689 | 687 | continue |
690 | matches = mxc_re.match(url) | |
688 | matches = MXC_REGEX.match(url) | |
691 | 689 | if matches: |
692 | 690 | hostname = matches.group(1) |
693 | 691 | media_id = matches.group(2) |
1165 | 1163 | # It's overridden by RoomStore for the synapse master. |
1166 | 1164 | raise NotImplementedError() |
1167 | 1165 | |
1166 | async def has_auth_chain_index(self, room_id: str) -> bool: | |
1167 | """Check if the room has (or can have) a chain cover index. | |
1168 | ||
1169 | Defaults to True if we have neither an entry in the `rooms` table nor any | |
1170 | events for the room. | |
1171 | """ | |
1172 | ||
1173 | has_auth_chain_index = await self.db_pool.simple_select_one_onecol( | |
1174 | table="rooms", | |
1175 | keyvalues={"room_id": room_id}, | |
1176 | retcol="has_auth_chain_index", | |
1177 | desc="has_auth_chain_index", | |
1178 | allow_none=True, | |
1179 | ) | |
1180 | ||
1181 | if has_auth_chain_index: | |
1182 | return True | |
1183 | ||
1184 | # It's possible that we already have events for the room in our DB | |
1185 | # without a corresponding room entry. If we do then we don't want to | |
1186 | # mark the room as having an auth chain cover index. | |
1187 | max_ordering = await self.db_pool.simple_select_one_onecol( | |
1188 | table="events", | |
1189 | keyvalues={"room_id": room_id}, | |
1190 | retcol="MAX(stream_ordering)", | |
1191 | allow_none=True, | |
1192 | desc="upsert_room_on_join", | |
1193 | ) | |
1194 | ||
1195 | return max_ordering is None | |
1196 | ||
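# (Editorial summary, not part of the upstream diff: has_auth_chain_index
# returns True when the flag is already set; False when the flag is unset
# but events already exist for the room, since those events predate the
# index; and True when neither a room row nor any events exist, as a
# brand-new room can be indexed from the start.)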
1168 | 1197 | |
1169 | 1198 | class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): |
1170 | 1199 | def __init__(self, database: DatabasePool, db_conn, hs): |
1178 | 1207 | Called when we join a room over federation, and overwrites any room version |
1179 | 1208 | currently in the table. |
1180 | 1209 | """ |
1210 | # It's possible that we already have events for the room in our DB | |
1211 | # without a corresponding room entry. If we do then we don't want to | |
1212 | # mark the room as having an auth chain cover index. | |
1213 | has_auth_chain_index = await self.has_auth_chain_index(room_id) | |
1214 | ||
1181 | 1215 | await self.db_pool.simple_upsert( |
1182 | 1216 | desc="upsert_room_on_join", |
1183 | 1217 | table="rooms", |
1184 | 1218 | keyvalues={"room_id": room_id}, |
1185 | 1219 | values={"room_version": room_version.identifier}, |
1186 | insertion_values={"is_public": False, "creator": ""}, | |
1220 | insertion_values={ | |
1221 | "is_public": False, | |
1222 | "creator": "", | |
1223 | "has_auth_chain_index": has_auth_chain_index, | |
1224 | }, | |
1187 | 1225 | # rooms has a unique constraint on room_id, so no need to lock when doing an |
1188 | 1226 | # emulated upsert. |
1189 | 1227 | lock=False, |
1218 | 1256 | "creator": room_creator_user_id, |
1219 | 1257 | "is_public": is_public, |
1220 | 1258 | "room_version": room_version.identifier, |
1259 | "has_auth_chain_index": True, | |
1221 | 1260 | }, |
1222 | 1261 | ) |
1223 | 1262 | if is_public: |
1246 | 1285 | When we receive an invite or any other event over federation that may relate to a room |
1247 | 1286 | we are not in, store the version of the room if we don't already know the room version. |
1248 | 1287 | """ |
1288 | # It's possible that we already have events for the room in our DB | |
1289 | # without a corresponding room entry. If we do then we don't want to | |
1290 | # mark the room as having an auth chain cover index. | |
1291 | has_auth_chain_index = await self.has_auth_chain_index(room_id) | |
1292 | ||
1249 | 1293 | await self.db_pool.simple_upsert( |
1250 | 1294 | desc="maybe_store_room_on_outlier_membership", |
1251 | 1295 | table="rooms", |
1255 | 1299 | "room_version": room_version.identifier, |
1256 | 1300 | "is_public": False, |
1257 | 1301 | "creator": "", |
1302 | "has_auth_chain_index": has_auth_chain_index, | |
1258 | 1303 | }, |
1259 | 1304 | # rooms has a unique constraint on room_id, so no need to lock when doing an |
1260 | 1305 | # emulated upsert. |
0 | /* Copyright 2020 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | ALTER TABLE access_tokens DROP COLUMN last_used;
0 | /* | |
1 | * Copyright 2020 The Matrix.org Foundation C.I.C. | |
2 | * | |
3 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | * you may not use this file except in compliance with the License. | |
5 | * You may obtain a copy of the License at | |
6 | * | |
7 | * http://www.apache.org/licenses/LICENSE-2.0 | |
8 | * | |
9 | * Unless required by applicable law or agreed to in writing, software | |
10 | * distributed under the License is distributed on an "AS IS" BASIS, | |
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | * See the License for the specific language governing permissions and | |
13 | * limitations under the License. | |
14 | */ | |
15 | ||
16 | -- Dropping last_used column from access_tokens table. | |
17 | ||
18 | CREATE TABLE access_tokens2 ( | |
19 | id BIGINT PRIMARY KEY, | |
20 | user_id TEXT NOT NULL, | |
21 | device_id TEXT, | |
22 | token TEXT NOT NULL, | |
23 | valid_until_ms BIGINT, | |
24 | puppets_user_id TEXT, | |
25 | last_validated BIGINT, | |
26 | UNIQUE(token) | |
27 | ); | |
28 | ||
29 | INSERT INTO access_tokens2(id, user_id, device_id, token) | |
30 | SELECT id, user_id, device_id, token FROM access_tokens; | |
31 | ||
32 | DROP TABLE access_tokens; | |
33 | ALTER TABLE access_tokens2 RENAME TO access_tokens; | |
34 | ||
35 | CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); | |
36 | ||
37 | ||
38 | -- Re-adding foreign key reference in event_txn_id table | |
39 | ||
40 | CREATE TABLE event_txn_id2 ( | |
41 | event_id TEXT NOT NULL, | |
42 | room_id TEXT NOT NULL, | |
43 | user_id TEXT NOT NULL, | |
44 | token_id BIGINT NOT NULL, | |
45 | txn_id TEXT NOT NULL, | |
46 | inserted_ts BIGINT NOT NULL, | |
47 | FOREIGN KEY (event_id) | |
48 | REFERENCES events (event_id) ON DELETE CASCADE, | |
49 | FOREIGN KEY (token_id) | |
50 | REFERENCES access_tokens (id) ON DELETE CASCADE | |
51 | ); | |
52 | ||
53 | INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts) | |
54 | SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id; | |
55 | ||
56 | DROP TABLE event_txn_id; | |
57 | ALTER TABLE event_txn_id2 RENAME TO event_txn_id; | |
58 | ||
59 | CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id); | |
60 | CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id); | |
61 | CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
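-- (Editorial note, not part of the migration: SQLite cannot drop a column
-- with ALTER TABLE in the versions Synapse targets, hence the
-- create/copy/drop/rename dance above; the preceding file in this diff
-- makes the same change on PostgreSQL with a plain
-- ALTER TABLE ... DROP COLUMN.)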
0 | /* Copyright 2020 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | INSERT INTO background_updates (ordering, update_name, progress_json) VALUES | |
16 | (5828, 'rejected_events_metadata', '{}'); |
0 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | """ | |
15 | This migration denormalises the account_data table into an ignored users table. | |
16 | """ | |
17 | ||
18 | import logging | |
19 | from io import StringIO | |
20 | ||
21 | from synapse.storage._base import db_to_json | |
22 | from synapse.storage.engines import BaseDatabaseEngine | |
23 | from synapse.storage.prepare_database import execute_statements_from_stream | |
24 | from synapse.storage.types import Cursor | |
25 | ||
26 | logger = logging.getLogger(__name__) | |
27 | ||
28 | ||
29 | def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): | |
30 | pass | |
31 | ||
32 | ||
33 | def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): | |
34 | logger.info("Creating ignored_users table") | |
35 | execute_statements_from_stream(cur, StringIO(_create_commands)) | |
36 | ||
37 | # We now upgrade existing data, if any. We don't do this in `run_upgrade` as | |
38 | # we a) want to run these before adding constraints and b) `run_upgrade` is | |
39 | # not run on empty databases. | |
40 | insert_sql = """ | |
41 | INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?) | |
42 | """ | |
43 | ||
44 | logger.info("Converting existing ignore lists") | |
45 | cur.execute( | |
46 | "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'" | |
47 | ) | |
48 | for user_id, content_json in cur.fetchall(): | |
49 | content = db_to_json(content_json) | |
50 | ||
51 | # The content should take the form of a dictionary with a key | |
52 | # "ignored_users" pointing to a dictionary keyed by the ignored user IDs: | |
53 | # | |
54 | # { "ignored_users": { "@someone:example.org": {} } } | |
55 | ignored_users = content.get("ignored_users", {}) | |
56 | if isinstance(ignored_users, dict) and ignored_users: | |
57 | cur.executemany(insert_sql, [(user_id, u) for u in ignored_users]) | |
58 | ||
59 | # Add indexes after inserting data for efficiency. | |
60 | logger.info("Adding constraints to ignored_users table") | |
61 | execute_statements_from_stream(cur, StringIO(_constraints_commands)) | |
62 | ||
63 | ||
64 | # There may be duplicates in the existing ignore lists, so the table is created | |
65 | # without constraints; the unique index is added only after the data is inserted. | |
66 | ||
67 | _create_commands = """ | |
68 | -- Users which are ignored when calculating push notifications. This data is | |
69 | -- denormalized from account data. | |
70 | CREATE TABLE IF NOT EXISTS ignored_users( | |
71 | ignorer_user_id TEXT NOT NULL, -- The user ID of the user who is ignoring another user. (This is a local user.) | |
72 | ignored_user_id TEXT NOT NULL -- The user ID of the user who is being ignored. (This is a local or remote user.) | |
73 | ); | |
74 | """ | |
75 | ||
76 | _constraints_commands = """ | |
77 | CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id); | |
78 | ||
79 | -- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user. | |
80 | CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id); | |
81 | """ |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | ALTER TABLE device_inbox ADD COLUMN instance_name TEXT; | |
16 | ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT; | |
17 | ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT; |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence; | |
16 | ||
17 | -- We need to take the max across both device_inbox and device_federation_outbox | |
18 | -- tables as they share the ID generator | |
19 | SELECT setval('device_inbox_sequence', ( | |
20 | SELECT GREATEST( | |
21 | (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox), | |
22 | (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox) | |
23 | ) | |
24 | )); |
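The `setval`/`GREATEST` pattern above recurs for every sequence that backs an ID generator shared between tables. A sketch of the same pattern driven from Python (assuming a DB-API cursor on Postgres; `seed_sequence` is a hypothetical helper, not part of Synapse):

    def seed_sequence(cur, sequence, tables, id_column="stream_id"):
        # Point the sequence at the highest stream ID seen in any of the
        # tables that share the generator (defaulting to 1 if all are empty).
        # Identifiers are interpolated directly, so they must be trusted.
        selects = ", ".join(
            "(SELECT COALESCE(MAX(%s), 1) FROM %s)" % (id_column, table)
            for table in tables
        )
        cur.execute(
            "SELECT setval('%s', (SELECT GREATEST(%s)))" % (sequence, selects)
        )

    # seed_sequence(cur, "device_inbox_sequence",
    #               ["device_inbox", "device_federation_outbox"])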
0 | /* Copyright 2020 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | -- See docs/auth_chain_difference_algorithm.md | |
16 | ||
17 | CREATE TABLE event_auth_chains ( | |
18 | event_id TEXT PRIMARY KEY, | |
19 | chain_id BIGINT NOT NULL, | |
20 | sequence_number BIGINT NOT NULL | |
21 | ); | |
22 | ||
23 | CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number); | |
24 | ||
25 | ||
26 | CREATE TABLE event_auth_chain_links ( | |
27 | origin_chain_id BIGINT NOT NULL, | |
28 | origin_sequence_number BIGINT NOT NULL, | |
29 | ||
30 | target_chain_id BIGINT NOT NULL, | |
31 | target_sequence_number BIGINT NOT NULL | |
32 | ); | |
33 | ||
34 | ||
35 | CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id); | |
36 | ||
37 | ||
38 | -- Events that we have persisted but not calculated auth chains for, | |
39 | -- e.g. out of band memberships (where we don't have the auth chain) | |
40 | CREATE TABLE event_auth_chain_to_calculate ( | |
41 | event_id TEXT PRIMARY KEY, | |
42 | room_id TEXT NOT NULL, | |
43 | type TEXT NOT NULL, | |
44 | state_key TEXT NOT NULL | |
45 | ); | |
46 | ||
47 | CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id); | |
48 | ||
49 | ||
50 | -- Whether we've calculated the above index for a room. | |
51 | ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN; |
0 | /* Copyright 2020 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id; |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | -- This is no longer used and was only kept until we bumped the schema version. | |
16 | DROP TABLE IF EXISTS account_data_max_stream_id; |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | -- This is no longer used and was only kept until we bumped the schema version. | |
16 | DROP TABLE IF EXISTS cache_invalidation_stream; |
0 | /* Copyright 2020 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES | |
16 | (5906, 'chain_cover', '{}', 'rejected_events_metadata'); |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | ALTER TABLE room_account_data ADD COLUMN instance_name TEXT; | |
16 | ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT; | |
17 | ALTER TABLE account_data ADD COLUMN instance_name TEXT; | |
18 | ||
19 | ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | CREATE SEQUENCE IF NOT EXISTS account_data_sequence; | |
16 | ||
17 | -- We need to take the max across all the account_data tables as they share the | |
18 | -- ID generator | |
19 | SELECT setval('account_data_sequence', ( | |
20 | SELECT GREATEST( | |
21 | (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data), | |
22 | (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions), | |
23 | (SELECT COALESCE(MAX(stream_id), 1) FROM account_data) | |
24 | ) | |
25 | )); | |
26 | ||
27 | CREATE SEQUENCE IF NOT EXISTS receipts_sequence; | |
28 | ||
29 | SELECT setval('receipts_sequence', ( | |
30 | SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized | |
31 | )); |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | -- We incorrectly populated these, so we delete them and let the | |
16 | -- MultiWriterIdGenerator repopulate it. | |
17 | DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data'; |
182 | 182 | ) |
183 | 183 | return {row["tag"]: db_to_json(row["content"]) for row in rows} |
184 | 184 | |
185 | ||
186 | class TagsStore(TagsWorkerStore): | |
187 | 185 | async def add_tag_to_room( |
188 | 186 | self, user_id: str, room_id: str, tag: str, content: JsonDict |
189 | 187 | ) -> int: |
198 | 196 | Returns: |
199 | 197 | The next account data ID. |
200 | 198 | """ |
199 | assert self._can_write_to_account_data | |
200 | ||
201 | 201 | content_json = json_encoder.encode(content) |
202 | 202 | |
203 | 203 | def add_tag_txn(txn, next_id): |
222 | 222 | Returns: |
223 | 223 | The next account data ID. |
224 | 224 | """ |
225 | assert self._can_write_to_account_data | |
225 | 226 | |
226 | 227 | def remove_tag_txn(txn, next_id): |
227 | 228 | sql = ( |
249 | 250 | room_id: The ID of the room. |
250 | 251 | next_id: The revision to advance to. |
251 | 252 | """ |
253 | assert self._can_write_to_account_data | |
252 | 254 | |
253 | 255 | txn.call_after( |
254 | 256 | self._account_data_stream_cache.entity_has_changed, user_id, next_id |
255 | 257 | ) |
256 | ||
257 | # Note: This is only here for backwards compat to allow admins to | |
258 | # roll back to a previous Synapse version. Next time we update the | |
259 | # database version we can remove this table. | |
260 | update_max_id_sql = ( | |
261 | "UPDATE account_data_max_stream_id" | |
262 | " SET stream_id = ?" | |
263 | " WHERE stream_id < ?" | |
264 | ) | |
265 | txn.execute(update_max_id_sql, (next_id, next_id)) | |
266 | 258 | |
267 | 259 | update_sql = ( |
268 | 260 | "UPDATE room_tags_revisions" |
287 | 279 | # which stream_id ends up in the table, as long as it is higher |
288 | 280 | # than the id that the client has. |
289 | 281 | pass |
282 | ||
283 | ||
284 | class TagsStore(TagsWorkerStore): | |
285 | pass |
463 | 463 | txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] |
464 | 464 | ) -> List[str]: |
465 | 465 | q = """ |
466 | SELECT destination FROM destinations | |
467 | WHERE destination IN ( | |
468 | SELECT destination FROM destination_rooms | |
469 | WHERE destination_rooms.stream_ordering > | |
470 | destinations.last_successful_stream_ordering | |
471 | ) | |
472 | AND destination > ? | |
473 | AND ( | |
474 | retry_last_ts IS NULL OR | |
475 | retry_last_ts + retry_interval < ? | |
476 | ) | |
477 | ORDER BY destination | |
478 | LIMIT 25 | |
466 | SELECT DISTINCT destination FROM destinations | |
467 | INNER JOIN destination_rooms USING (destination) | |
468 | WHERE | |
469 | stream_ordering > last_successful_stream_ordering | |
470 | AND destination > ? | |
471 | AND ( | |
472 | retry_last_ts IS NULL OR | |
473 | retry_last_ts + retry_interval < ? | |
474 | ) | |
475 | ORDER BY destination | |
476 | LIMIT 25 | |
479 | 477 | """ |
480 | 478 | txn.execute( |
481 | 479 | q, |
34 | 34 | |
35 | 35 | # Remember to update this number every time a change is made to database |
36 | 36 | # schema files, so the users will be informed on server restarts. |
37 | # XXX: If you're about to bump this to 59 (or higher) please create an update | |
38 | # that drops the unused `cache_invalidation_stream` table, as per #7436! | |
39 | # XXX: Also add an update to drop `account_data_max_stream_id` as per #7656! | |
40 | SCHEMA_VERSION = 58 | |
37 | SCHEMA_VERSION = 59 | |
41 | 38 | |
42 | 39 | dir_path = os.path.abspath(os.path.dirname(__file__)) |
43 | 40 | |
374 | 371 | specific_engine_extensions = (".sqlite", ".postgres") |
375 | 372 | |
376 | 373 | for v in range(start_ver, SCHEMA_VERSION + 1): |
377 | logger.info("Applying schema deltas for v%d", v) | |
374 | if not is_worker: | |
375 | logger.info("Applying schema deltas for v%d", v) | |
376 | ||
377 | cur.execute("DELETE FROM schema_version") | |
378 | cur.execute( | |
379 | "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", | |
380 | (v, True), | |
381 | ) | |
382 | else: | |
383 | logger.info("Checking schema deltas for v%d", v) | |
378 | 384 | |
379 | 385 | # We need to search both the global and per data store schema |
380 | 386 | # directories for schema updates. |
488 | 494 | (v, relative_path), |
489 | 495 | ) |
490 | 496 | |
491 | cur.execute("DELETE FROM schema_version") | |
492 | cur.execute( | |
493 | "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", | |
494 | (v, True), | |
495 | ) | |
496 | ||
497 | 497 | logger.info("Schema now up to date") |
498 | 498 | |
499 | 499 |
16 | 16 | import threading |
17 | 17 | from collections import deque |
18 | 18 | from contextlib import contextmanager |
19 | from typing import Dict, List, Optional, Set, Union | |
19 | from typing import Dict, List, Optional, Set, Tuple, Union | |
20 | 20 | |
21 | 21 | import attr |
22 | 22 | from typing_extensions import Deque |
185 | 185 | Args: |
186 | 186 | db_conn |
187 | 187 | db |
188 | stream_name: A name for the stream. | |
188 | stream_name: A name for the stream, for use in the `stream_positions` | |
189 | table. (Does not need to be the same as the replication stream name.) | |
189 | 190 | instance_name: The name of this instance. |
190 | table: Database table associated with stream. | |
191 | instance_column: Column that stores the row's writer's instance name | |
192 | id_column: Column that stores the stream ID. | |
191 | tables: List of tables associated with the stream. Tuple of table | |
192 | name, column name that stores the writer's instance name, and | |
193 | column name that stores the stream ID. | |
193 | 194 | sequence_name: The name of the postgres sequence used to generate new |
194 | 195 | IDs. |
195 | 196 | writers: A list of known writers to use to populate current positions |
205 | 206 | db: DatabasePool, |
206 | 207 | stream_name: str, |
207 | 208 | instance_name: str, |
208 | table: str, | |
209 | instance_column: str, | |
210 | id_column: str, | |
209 | tables: List[Tuple[str, str, str]], | |
211 | 210 | sequence_name: str, |
212 | 211 | writers: List[str], |
213 | 212 | positive: bool = True, |
259 | 258 | self._sequence_gen = PostgresSequenceGenerator(sequence_name) |
260 | 259 | |
261 | 260 | # We check that the table and sequence haven't diverged. |
262 | self._sequence_gen.check_consistency( | |
263 | db_conn, table=table, id_column=id_column, positive=positive | |
264 | ) | |
261 | for table, _, id_column in tables: | |
262 | self._sequence_gen.check_consistency( | |
263 | db_conn, | |
264 | table=table, | |
265 | id_column=id_column, | |
266 | stream_name=stream_name, | |
267 | positive=positive, | |
268 | ) | |
265 | 269 | |
266 | 270 | # This goes and fills out the above state from the database. |
267 | self._load_current_ids(db_conn, table, instance_column, id_column) | |
271 | self._load_current_ids(db_conn, tables) | |
268 | 272 | |
269 | 273 | def _load_current_ids( |
270 | self, db_conn, table: str, instance_column: str, id_column: str | |
274 | self, db_conn, tables: List[Tuple[str, str, str]], | |
271 | 275 | ): |
272 | 276 | cur = db_conn.cursor(txn_name="_load_current_ids") |
273 | 277 | |
305 | 309 | # We add a GREATEST here to ensure that the result is always |
306 | 310 | # positive. (This can be a problem for e.g. backfill streams where |
307 | 311 | # the server has never backfilled). |
308 | sql = """ | |
309 | SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) | |
310 | FROM %(table)s | |
311 | """ % { | |
312 | "id": id_column, | |
313 | "table": table, | |
314 | "agg": "MAX" if self._positive else "-MIN", | |
315 | } | |
316 | cur.execute(sql) | |
317 | (stream_id,) = cur.fetchone() | |
318 | self._persisted_upto_position = stream_id | |
312 | max_stream_id = 1 | |
313 | for table, _, id_column in tables: | |
314 | sql = """ | |
315 | SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) | |
316 | FROM %(table)s | |
317 | """ % { | |
318 | "id": id_column, | |
319 | "table": table, | |
320 | "agg": "MAX" if self._positive else "-MIN", | |
321 | } | |
322 | cur.execute(sql) | |
323 | (stream_id,) = cur.fetchone() | |
324 | ||
325 | max_stream_id = max(max_stream_id, stream_id) | |
326 | ||
327 | self._persisted_upto_position = max_stream_id | |
319 | 328 | else: |
320 | 329 | # If we have a min_stream_id then we pull out everything greater |
321 | 330 | # than it from the DB so that we can prefill |
328 | 337 | # stream positions table before restart (or the stream position |
329 | 338 | # table otherwise got out of date). |
330 | 339 | |
331 | sql = """ | |
332 | SELECT %(instance)s, %(id)s FROM %(table)s | |
333 | WHERE ? %(cmp)s %(id)s | |
334 | """ % { | |
335 | "id": id_column, | |
336 | "table": table, | |
337 | "instance": instance_column, | |
338 | "cmp": "<=" if self._positive else ">=", | |
339 | } | |
340 | cur.execute(sql, (min_stream_id * self._return_factor,)) | |
341 | ||
342 | 340 | self._persisted_upto_position = min_stream_id |
343 | 341 | |
342 | rows = [] | |
343 | for table, instance_column, id_column in tables: | |
344 | sql = """ | |
345 | SELECT %(instance)s, %(id)s FROM %(table)s | |
346 | WHERE ? %(cmp)s %(id)s | |
347 | """ % { | |
348 | "id": id_column, | |
349 | "table": table, | |
350 | "instance": instance_column, | |
351 | "cmp": "<=" if self._positive else ">=", | |
352 | } | |
353 | cur.execute(sql, (min_stream_id * self._return_factor,)) | |
354 | ||
355 | rows.extend(cur) | |
356 | ||
357 | # Sort so that we handle rows in order for each instance. | |
358 | rows.sort() | |
359 | ||
344 | 360 | with self._lock: |
345 | for (instance, stream_id,) in cur: | |
361 | for (instance, stream_id,) in rows: | |
346 | 362 | stream_id = self._return_factor * stream_id |
347 | 363 | self._add_persisted_position(stream_id) |
348 | 364 |
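With the `table`/`instance_column`/`id_column` trio replaced by a `tables` list, a single generator can now be fed by several tables. A hypothetical call site under the new signature (the stream, sequence and writer names here are illustrative, not taken from a real call site):

    id_gen = MultiWriterIdGenerator(
        db_conn,
        db,
        stream_name="to_device",
        instance_name=instance_name,
        # Each entry: (table, column holding the writer's instance name,
        # column holding the stream ID).
        tables=[
            ("device_inbox", "instance_name", "stream_id"),
            ("device_federation_outbox", "instance_name", "stream_id"),
        ],
        sequence_name="device_inbox_sequence",
        writers=["master"],
    )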
14 | 14 | import abc |
15 | 15 | import logging |
16 | 16 | import threading |
17 | from typing import Callable, List, Optional | |
18 | ||
19 | from synapse.storage.database import LoggingDatabaseConnection | |
17 | from typing import TYPE_CHECKING, Callable, List, Optional | |
18 | ||
20 | 19 | from synapse.storage.engines import ( |
21 | 20 | BaseDatabaseEngine, |
22 | 21 | IncorrectDatabaseSetup, |
24 | 23 | ) |
25 | 24 | from synapse.storage.types import Connection, Cursor |
26 | 25 | |
26 | if TYPE_CHECKING: | |
27 | from synapse.storage.database import LoggingDatabaseConnection | |
28 | ||
27 | 29 | logger = logging.getLogger(__name__) |
28 | 30 | |
29 | 31 | |
42 | 44 | See docs/postgres.md for more information. |
43 | 45 | """ |
44 | 46 | |
47 | _INCONSISTENT_STREAM_ERROR = """ | |
48 | Postgres sequence '%(seq)s' is inconsistent with associated stream position | |
49 | of '%(stream_name)s' in the 'stream_positions' table. | |
50 | ||
51 | This is likely a programming error and should be reported at | |
52 | https://github.com/matrix-org/synapse. | |
53 | ||
54 | A temporary workaround to fix this error is to shut down Synapse (including | |
55 | any and all workers) and run the following SQL: | |
56 | ||
57 | DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s'; | |
58 | ||
59 | This will need to be done every time the server is restarted. | |
60 | """ | |
61 | ||
45 | 62 | |
46 | 63 | class SequenceGenerator(metaclass=abc.ABCMeta): |
47 | 64 | """A class which generates a unique sequence of integers""" |
54 | 71 | @abc.abstractmethod |
55 | 72 | def check_consistency( |
56 | 73 | self, |
57 | db_conn: LoggingDatabaseConnection, | |
74 | db_conn: "LoggingDatabaseConnection", | |
58 | 75 | table: str, |
59 | 76 | id_column: str, |
77 | stream_name: Optional[str] = None, | |
60 | 78 | positive: bool = True, |
61 | 79 | ): |
62 | 80 | """Should be called during start up to test that the current value of |
63 | 81 | the sequence is greater than or equal to the maximum ID in the table. |
64 | 82 | |
65 | This is to handle various cases where the sequence value can get out | |
66 | of sync with the table, e.g. if Synapse gets rolled back to a previous | |
83 | This is to handle various cases where the sequence value can get out of | |
84 | sync with the table, e.g. if Synapse gets rolled back to a previous | |
67 | 85 | version and then rolled forwards again. |
86 | ||
87 | If a stream name is given then this will check that any value in the | |
88 | `stream_positions` table is less than or equal to the current sequence | |
89 | value. If it isn't then it's likely that streams have been crossed | |
90 | somewhere (e.g. two ID generators have the same stream name). | |
68 | 91 | """ |
69 | 92 | ... |
70 | 93 | |
87 | 110 | |
88 | 111 | def check_consistency( |
89 | 112 | self, |
90 | db_conn: LoggingDatabaseConnection, | |
113 | db_conn: "LoggingDatabaseConnection", | |
91 | 114 | table: str, |
92 | 115 | id_column: str, |
116 | stream_name: Optional[str] = None, | |
93 | 117 | positive: bool = True, |
94 | 118 | ): |
119 | """See SequenceGenerator.check_consistency for docstring. | |
120 | """ | |
121 | ||
95 | 122 | txn = db_conn.cursor(txn_name="sequence.check_consistency") |
96 | 123 | |
97 | 124 | # First we get the current max ID from the table. |
115 | 142 | "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} |
116 | 143 | ) |
117 | 144 | last_value, is_called = txn.fetchone() |
145 | ||
146 | # If we have an associated stream check the stream_positions table. | |
147 | max_in_stream_positions = None | |
148 | if stream_name: | |
149 | txn.execute( | |
150 | "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?", | |
151 | (stream_name,), | |
152 | ) | |
153 | row = txn.fetchone() | |
154 | if row: | |
155 | max_in_stream_positions = row[0] | |
156 | ||
118 | 157 | txn.close() |
119 | 158 | |
120 | 159 | # If `is_called` is False then `last_value` is actually the value that |
133 | 172 | raise IncorrectDatabaseSetup( |
134 | 173 | _INCONSISTENT_SEQUENCE_ERROR |
135 | 174 | % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} |
175 | ) | |
176 | ||
177 | # If we have values in the stream positions table then they have to be | |
178 | # less than or equal to `last_value` | |
179 | if max_in_stream_positions and max_in_stream_positions > last_value: | |
180 | raise IncorrectDatabaseSetup( | |
181 | _INCONSISTENT_STREAM_ERROR | |
182 | % {"seq": self._sequence_name, "stream_name": stream_name} | |
136 | 183 | ) |
137 | 184 | |
138 | 185 | |
172 | 219 | return self._current_max_id |
173 | 220 | |
174 | 221 | def check_consistency( |
175 | self, db_conn: Connection, table: str, id_column: str, positive: bool = True | |
222 | self, | |
223 | db_conn: Connection, | |
224 | table: str, | |
225 | id_column: str, | |
226 | stream_name: Optional[str] = None, | |
227 | positive: bool = True, | |
176 | 228 | ): |
177 | 229 | # There is nothing to do for in memory sequences |
178 | 230 | pass |
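The new `stream_name` argument boils down to one extra invariant: nothing recorded for the stream in `stream_positions` may be ahead of the sequence itself. A minimal sketch of that predicate (a hypothetical helper; the real check also applies the `is_called` adjustment shown above before comparing):

    def stream_positions_consistent(last_value, max_in_stream_positions):
        # Consistent if there is no recorded position for the stream, or
        # every recorded position is at or behind the sequence value.
        if max_in_stream_positions is None:
            return True
        return max_in_stream_positions <= last_value

    # stream_positions_consistent(105, 100)  -> True
    # stream_positions_consistent(100, 105)  -> False (Synapse raises
    #                                            IncorrectDatabaseSetup)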
36 | 36 | from unpaddedbase64 import decode_base64 |
37 | 37 | |
38 | 38 | from synapse.api.errors import Codes, SynapseError |
39 | from synapse.util.stringutils import parse_and_validate_server_name | |
39 | 40 | |
40 | 41 | if TYPE_CHECKING: |
41 | 42 | from synapse.appservice.api import ApplicationService |
256 | 257 | |
257 | 258 | @classmethod |
258 | 259 | def is_valid(cls: Type[DS], s: str) -> bool: |
260 | """Parses the input string and attempts to ensure it is valid.""" | |
259 | 261 | try: |
260 | cls.from_string(s) | |
262 | obj = cls.from_string(s) | |
263 | # Apply additional validation to the domain. This is only done | |
264 | # during is_valid (and not part of from_string) since it is | |
265 | # possible for invalid data to exist in room-state, etc. | |
266 | parse_and_validate_server_name(obj.domain) | |
261 | 267 | return True |
262 | 268 | except Exception: |
263 | 269 | return False |
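A quick illustration of the behaviour change (assuming `UserID` from `synapse.types`, as in the diff; the example values are illustrative):

    from synapse.types import UserID

    # Parses, and the domain passes server-name validation.
    assert UserID.is_valid("@alice:example.org")

    # Still parses as a UserID, but the domain now fails
    # parse_and_validate_server_name, so is_valid() returns False.
    assert not UserID.is_valid("@alice:exa mple.org")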
104 | 104 | keylen=keylen, |
105 | 105 | cache_name=name, |
106 | 106 | cache_type=cache_type, |
107 | size_callback=(lambda d: len(d)) if iterable else None, | |
107 | size_callback=(lambda d: len(d) or 1) if iterable else None, | |
108 | 108 | metrics_collection_callback=metrics_cb, |
109 | 109 | apply_cache_factor_from_config=apply_cache_factor_from_config, |
110 | 110 | ) # type: LruCache[KT, VT] |
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | import heapq | |
15 | 16 | from itertools import islice |
16 | from typing import Iterable, Iterator, Sequence, Tuple, TypeVar | |
17 | from typing import ( | |
18 | Dict, | |
19 | Generator, | |
20 | Iterable, | |
21 | Iterator, | |
22 | Mapping, | |
23 | Sequence, | |
24 | Set, | |
25 | Tuple, | |
26 | TypeVar, | |
27 | ) | |
28 | ||
29 | from synapse.types import Collection | |
17 | 30 | |
18 | 31 | T = TypeVar("T") |
19 | 32 | |
45 | 58 | If the input is empty, no chunks are returned. |
46 | 59 | """ |
47 | 60 | return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) |
61 | ||
62 | ||
63 | def sorted_topologically( | |
64 | nodes: Iterable[T], graph: Mapping[T, Collection[T]], | |
65 | ) -> Generator[T, None, None]: | |
66 | """Given a set of nodes and a graph, yield the nodes in toplogical order. | |
67 | ||
68 | For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`. | |
69 | """ | |
70 | ||
71 | # This is implemented by Kahn's algorithm. | |
72 | ||
73 | degree_map = {node: 0 for node in nodes} | |
74 | reverse_graph = {} # type: Dict[T, Set[T]] | |
75 | ||
76 | for node, edges in graph.items(): | |
77 | if node not in degree_map: | |
78 | continue | |
79 | ||
80 | for edge in set(edges): | |
81 | if edge in degree_map: | |
82 | degree_map[node] += 1 | |
83 | ||
84 | reverse_graph.setdefault(edge, set()).add(node) | |
85 | reverse_graph.setdefault(node, set()) | |
86 | ||
87 | zero_degree = [node for node, degree in degree_map.items() if degree == 0] | |
88 | heapq.heapify(zero_degree) | |
89 | ||
90 | while zero_degree: | |
91 | node = heapq.heappop(zero_degree) | |
92 | yield node | |
93 | ||
94 | for edge in reverse_graph.get(node, []): | |
95 | if edge in degree_map: | |
96 | degree_map[edge] -= 1 | |
97 | if degree_map[edge] == 0: | |
98 | heapq.heappush(zero_degree, edge) |
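Usage of the new helper, for reference: the graph maps each node to the nodes it depends on, so dependencies are yielded first, and ties are broken smallest-first because the zero-degree set is kept as a heap.

    assert list(sorted_topologically([1, 2, 3], {1: [2], 2: [3]})) == [3, 2, 1]
    assert list(sorted_topologically([1, 2, 3], {})) == [1, 2, 3]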
107 | 107 | def __init__(self, clock, name): |
108 | 108 | self.clock = clock |
109 | 109 | self.name = name |
110 | parent_context = current_context() | |
110 | curr_context = current_context() | |
111 | if not curr_context: | |
112 | logger.warning( | |
113 | "Starting metrics collection %r from sentinel context: metrics will be lost", | |
114 | name, | |
115 | ) | |
116 | parent_context = None | |
117 | else: | |
118 | assert isinstance(curr_context, LoggingContext) | |
119 | parent_context = curr_context | |
111 | 120 | self._logging_context = LoggingContext( |
112 | 121 | "Measure[%s]" % (self.name,), parent_context |
113 | 122 | ) |
17 | 17 | import re |
18 | 18 | import string |
19 | 19 | from collections.abc import Iterable |
20 | from typing import Optional, Tuple | |
20 | 21 | |
21 | 22 | from synapse.api.errors import Codes, SynapseError |
22 | 23 | |
24 | 25 | |
25 | 26 | # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken |
26 | 27 | client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") |
28 | ||
29 | # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, | |
30 | # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically | |
31 | # says "there is no grammar for media ids" | |
32 | # | |
33 | # The server_name part of this is purposely lax: use parse_and_validate_mxc_uri | |
34 | # for additional validation. | |
35 | # | |
36 | MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") | |
27 | 37 | |
28 | 38 | # random_string and random_string_with_symbols are used for a range of things, |
29 | 39 | # some cryptographically important, some less so. We use SystemRandom to make sure |
58 | 68 | ) |
59 | 69 | |
60 | 70 | |
71 | def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: | |
72 | """Split a server name into host/port parts. | |
73 | ||
74 | Args: | |
75 | server_name: server name to parse | |
76 | ||
77 | Returns: | |
78 | host/port parts. | |
79 | ||
80 | Raises: | |
81 | ValueError if the server name could not be parsed. | |
82 | """ | |
83 | try: | |
84 | if server_name[-1] == "]": | |
85 | # ipv6 literal, hopefully | |
86 | return server_name, None | |
87 | ||
88 | domain_port = server_name.rsplit(":", 1) | |
89 | domain = domain_port[0] | |
90 | port = int(domain_port[1]) if domain_port[1:] else None | |
91 | return domain, port | |
92 | except Exception: | |
93 | raise ValueError("Invalid server name '%s'" % server_name) | |
94 | ||
95 | ||
96 | VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") | |
97 | ||
98 | ||
99 | def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: | |
100 | """Split a server name into host/port parts and do some basic validation. | |
101 | ||
102 | Args: | |
103 | server_name: server name to parse | |
104 | ||
105 | Returns: | |
106 | host/port parts. | |
107 | ||
108 | Raises: | |
109 | ValueError if the server name could not be parsed. | |
110 | """ | |
111 | host, port = parse_server_name(server_name) | |
112 | ||
113 | # these tests don't need to be bulletproof as we'll find out soon enough | |
114 | # if somebody is giving us invalid data. What we *do* need is to be sure | |
115 | # that nobody is sneaking IP literals in that look like hostnames, etc. | |
116 | ||
117 | # look for ipv6 literals | |
118 | if host[0] == "[": | |
119 | if host[-1] != "]": | |
120 | raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) | |
121 | return host, port | |
122 | ||
123 | # otherwise it should only contain alphanumerics, dots and dashes. | |
124 | if not VALID_HOST_REGEX.match(host): | |
125 | raise ValueError( | |
126 | "Server name '%s' contains invalid characters" % (server_name,) | |
127 | ) | |
128 | ||
129 | return host, port | |
130 | ||
131 | ||
132 | def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: | |
133 | """Parse the given string as an MXC URI | |
134 | ||
135 | Checks that the "server name" part is a valid server name | |
136 | ||
137 | Args: | |
138 | mxc: the (alleged) MXC URI to be checked | |
139 | Returns: | |
140 | hostname, port, media id | |
141 | Raises: | |
142 | ValueError if the URI cannot be parsed | |
143 | """ | |
144 | m = MXC_REGEX.match(mxc) | |
145 | if not m: | |
146 | raise ValueError("mxc URI %r did not match expected format" % (mxc,)) | |
147 | server_name = m.group(1) | |
148 | media_id = m.group(2) | |
149 | host, port = parse_and_validate_server_name(server_name) | |
150 | return host, port, media_id | |
151 | ||
152 | ||
61 | 153 | def shortstr(iterable: Iterable, maxitems: int = 5) -> str: |
62 | 154 | """If iterable has maxitems or fewer, return the stringification of a list |
63 | 155 | containing those items. |
74 | 166 | if len(items) <= maxitems: |
75 | 167 | return str(items) |
76 | 168 | return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" |
169 | ||
170 | ||
171 | def strtobool(val: str) -> bool: | |
172 | """Convert a string representation of truth to True or False | |
173 | ||
174 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values | |
175 | are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if | |
176 | 'val' is anything else. | |
177 | ||
178 | This is lifted from distutils.util.strtobool, with the exception that it actually | |
179 | returns a bool, rather than an int. | |
180 | """ | |
181 | val = val.lower() | |
182 | if val in ("y", "yes", "t", "true", "on", "1"): | |
183 | return True | |
184 | elif val in ("n", "no", "f", "false", "off", "0"): | |
185 | return False | |
186 | else: | |
187 | raise ValueError("invalid truth value %r" % (val,)) |
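A few usage examples for the new string helpers (assuming they are imported from `synapse.util.stringutils` as added above; values are illustrative):

    # parse_server_name splits without validating; the validating variants
    # additionally reject hosts that are not DNS names or IP literals.
    assert parse_server_name("matrix.org:8448") == ("matrix.org", 8448)
    assert parse_server_name("[::1]") == ("[::1]", None)

    assert parse_and_validate_mxc_uri("mxc://matrix.org/abc123") == (
        "matrix.org",
        None,
        "abc123",
    )

    assert strtobool("on") is True and strtobool("0") is False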
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from synapse.config import ConfigError | |
16 | from synapse.config._util import validate_config | |
17 | ||
18 | from tests.unittest import TestCase | |
19 | ||
20 | ||
21 | class ValidateConfigTestCase(TestCase): | |
22 | """Test cases for synapse.config._util.validate_config""" | |
23 | ||
24 | def test_bad_object_in_array(self): | |
25 | """malformed objects within an array should be validated correctly""" | |
26 | ||
27 | # consider a structure: | |
28 | # | |
29 | # array_of_objs: | |
30 | # - r: 1 | |
31 | # foo: 2 | |
32 | # | |
33 | # - r: 2 | |
34 | # bar: 3 | |
35 | # | |
36 | # ... where each entry must contain an "r": check that the path | |
37 | # to the required item is correctly reported. | |
38 | ||
39 | schema = { | |
40 | "type": "object", | |
41 | "properties": { | |
42 | "array_of_objs": { | |
43 | "type": "array", | |
44 | "items": {"type": "object", "required": ["r"]}, | |
45 | }, | |
46 | }, | |
47 | } | |
48 | ||
49 | with self.assertRaises(ConfigError) as c: | |
50 | validate_config(schema, {"array_of_objs": [{}]}, ("base",)) | |
51 | ||
52 | self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"]) |
33 | 33 | |
34 | 34 | |
35 | 35 | class PruneEventTestCase(unittest.TestCase): |
36 | """ Asserts that a new event constructed with `evdict` will look like | |
37 | `matchdict` when it is redacted. """ | |
38 | ||
39 | 36 | def run_test(self, evdict, matchdict, **kwargs): |
40 | self.assertEquals( | |
37 | """ | |
38 | Asserts that a new event constructed with `evdict` will look like | |
39 | `matchdict` when it is redacted. | |
40 | ||
41 | Args: | |
42 | evdict: The dictionary to build the event from. | |
43 | matchdict: The expected resulting dictionary. | |
44 | kwargs: Additional keyword arguments used to create the event. | |
45 | """ | |
46 | self.assertEqual( | |
41 | 47 | prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict |
42 | 48 | ) |
43 | 49 | |
54 | 60 | ) |
55 | 61 | |
56 | 62 | def test_basic_keys(self): |
57 | self.run_test( | |
58 | { | |
63 | """Ensure that the keys that should be untouched are kept.""" | |
64 | # Note that some of the values below don't really make sense, but the | |
65 | # pruning of events doesn't worry about the values of any fields (with | |
66 | # the exception of the content field). | |
67 | self.run_test( | |
68 | { | |
69 | "event_id": "$3:domain", | |
59 | 70 | "type": "A", |
60 | 71 | "room_id": "!1:domain", |
61 | 72 | "sender": "@2:domain", |
73 | "state_key": "B", | |
74 | "content": {"other_key": "foo"}, | |
75 | "hashes": "hashes", | |
76 | "signatures": {"domain": {"algo:1": "sigs"}}, | |
77 | "depth": 4, | |
78 | "prev_events": "prev_events", | |
79 | "prev_state": "prev_state", | |
80 | "auth_events": "auth_events", | |
81 | "origin": "domain", | |
82 | "origin_server_ts": 1234, | |
83 | "membership": "join", | |
84 | # Also include a key that should be removed. | |
85 | "other_key": "foo", | |
86 | }, | |
87 | { | |
62 | 88 | "event_id": "$3:domain", |
63 | "origin": "domain", | |
64 | }, | |
65 | { | |
66 | 89 | "type": "A", |
67 | 90 | "room_id": "!1:domain", |
68 | 91 | "sender": "@2:domain", |
69 | "event_id": "$3:domain", | |
92 | "state_key": "B", | |
93 | "hashes": "hashes", | |
94 | "depth": 4, | |
95 | "prev_events": "prev_events", | |
96 | "prev_state": "prev_state", | |
97 | "auth_events": "auth_events", | |
70 | 98 | "origin": "domain", |
71 | "content": {}, | |
72 | "signatures": {}, | |
73 | "unsigned": {}, | |
74 | }, | |
75 | ) | |
76 | ||
77 | def test_unsigned_age_ts(self): | |
78 | self.run_test( | |
79 | {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}}, | |
99 | "origin_server_ts": 1234, | |
100 | "membership": "join", | |
101 | "content": {}, | |
102 | "signatures": {"domain": {"algo:1": "sigs"}}, | |
103 | "unsigned": {}, | |
104 | }, | |
105 | ) | |
106 | ||
107 | # As of MSC2176 we now redact the membership and prev_state keys. | |
108 | self.run_test( | |
109 | {"type": "A", "prev_state": "prev_state", "membership": "join"}, | |
110 | {"type": "A", "content": {}, "signatures": {}, "unsigned": {}}, | |
111 | room_version=RoomVersions.MSC2176, | |
112 | ) | |
113 | ||
114 | def test_unsigned(self): | |
115 | """Ensure that unsigned properties get stripped (except age_ts and replaces_state).""" | |
116 | self.run_test( | |
80 | 117 | { |
81 | 118 | "type": "B", |
82 | 119 | "event_id": "$test:domain", |
83 | "content": {}, | |
84 | "signatures": {}, | |
85 | "unsigned": {"age_ts": 20}, | |
86 | }, | |
87 | ) | |
88 | ||
89 | self.run_test( | |
120 | "unsigned": { | |
121 | "age_ts": 20, | |
122 | "replaces_state": "$test2:domain", | |
123 | "other_key": "foo", | |
124 | }, | |
125 | }, | |
90 | 126 | { |
91 | 127 | "type": "B", |
92 | 128 | "event_id": "$test:domain", |
93 | "unsigned": {"other_key": "here"}, | |
94 | }, | |
95 | { | |
96 | "type": "B", | |
97 | "event_id": "$test:domain", | |
98 | "content": {}, | |
99 | "signatures": {}, | |
100 | "unsigned": {}, | |
129 | "content": {}, | |
130 | "signatures": {}, | |
131 | "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"}, | |
101 | 132 | }, |
102 | 133 | ) |
103 | 134 | |
104 | 135 | def test_content(self): |
136 | """The content dictionary should be stripped in most cases.""" | |
105 | 137 | self.run_test( |
106 | 138 | {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, |
107 | 139 | { |
113 | 145 | }, |
114 | 146 | ) |
115 | 147 | |
148 | # Some events keep a single content key/value. | |
149 | EVENT_KEEP_CONTENT_KEYS = [ | |
150 | ("member", "membership", "join"), | |
151 | ("join_rules", "join_rule", "invite"), | |
152 | ("history_visibility", "history_visibility", "shared"), | |
153 | ] | |
154 | for event_type, key, value in EVENT_KEEP_CONTENT_KEYS: | |
155 | self.run_test( | |
156 | { | |
157 | "type": "m.room." + event_type, | |
158 | "event_id": "$test:domain", | |
159 | "content": {key: value, "other_key": "foo"}, | |
160 | }, | |
161 | { | |
162 | "type": "m.room." + event_type, | |
163 | "event_id": "$test:domain", | |
164 | "content": {key: value}, | |
165 | "signatures": {}, | |
166 | "unsigned": {}, | |
167 | }, | |
168 | ) | |
169 | ||
170 | def test_create(self): | |
171 | """Create events are partially redacted until MSC2176.""" | |
116 | 172 | self.run_test( |
117 | 173 | { |
118 | 174 | "type": "m.room.create", |
119 | 175 | "event_id": "$test:domain", |
120 | "content": {"creator": "@2:domain", "other_field": "here"}, | |
176 | "content": {"creator": "@2:domain", "other_key": "foo"}, | |
121 | 177 | }, |
122 | 178 | { |
123 | 179 | "type": "m.room.create", |
126 | 182 | "signatures": {}, |
127 | 183 | "unsigned": {}, |
128 | 184 | }, |
185 | ) | |
186 | ||
187 | # After MSC2176, create events get nothing redacted. | |
188 | self.run_test( | |
189 | {"type": "m.room.create", "content": {"not_a_real_key": True}}, | |
190 | { | |
191 | "type": "m.room.create", | |
192 | "content": {"not_a_real_key": True}, | |
193 | "signatures": {}, | |
194 | "unsigned": {}, | |
195 | }, | |
196 | room_version=RoomVersions.MSC2176, | |
197 | ) | |
198 | ||
199 | def test_power_levels(self): | |
200 | """Power level events keep a variety of content keys.""" | |
201 | self.run_test( | |
202 | { | |
203 | "type": "m.room.power_levels", | |
204 | "event_id": "$test:domain", | |
205 | "content": { | |
206 | "ban": 1, | |
207 | "events": {"m.room.name": 100}, | |
208 | "events_default": 2, | |
209 | "invite": 3, | |
210 | "kick": 4, | |
211 | "redact": 5, | |
212 | "state_default": 6, | |
213 | "users": {"@admin:domain": 100}, | |
214 | "users_default": 7, | |
215 | "other_key": 8, | |
216 | }, | |
217 | }, | |
218 | { | |
219 | "type": "m.room.power_levels", | |
220 | "event_id": "$test:domain", | |
221 | "content": { | |
222 | "ban": 1, | |
223 | "events": {"m.room.name": 100}, | |
224 | "events_default": 2, | |
225 | # Note that invite is not here. | |
226 | "kick": 4, | |
227 | "redact": 5, | |
228 | "state_default": 6, | |
229 | "users": {"@admin:domain": 100}, | |
230 | "users_default": 7, | |
231 | }, | |
232 | "signatures": {}, | |
233 | "unsigned": {}, | |
234 | }, | |
235 | ) | |
236 | ||
237 | # After MSC2176, power levels events keep the invite key. | |
238 | self.run_test( | |
239 | {"type": "m.room.power_levels", "content": {"invite": 75}}, | |
240 | { | |
241 | "type": "m.room.power_levels", | |
242 | "content": {"invite": 75}, | |
243 | "signatures": {}, | |
244 | "unsigned": {}, | |
245 | }, | |
246 | room_version=RoomVersions.MSC2176, | |
129 | 247 | ) |
130 | 248 | |
131 | 249 | def test_alias_event(self): |
145 | 263 | }, |
146 | 264 | ) |
147 | 265 | |
148 | def test_msc2432_alias_event(self): | |
149 | """After MSC2432, alias events have no special behavior.""" | |
266 | # After MSC2432, alias events have no special behavior. | |
150 | 267 | self.run_test( |
151 | 268 | {"type": "m.room.aliases", "content": {"aliases": ["test"]}}, |
152 | 269 | { |
156 | 273 | "unsigned": {}, |
157 | 274 | }, |
158 | 275 | room_version=RoomVersions.V6, |
276 | ) | |
277 | ||
278 | def test_redacts(self): | |
279 | """Redaction events have no special behaviour until MSC2174/MSC2176.""" | |
280 | ||
281 | self.run_test( | |
282 | {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}}, | |
283 | { | |
284 | "type": "m.room.redaction", | |
285 | "content": {}, | |
286 | "signatures": {}, | |
287 | "unsigned": {}, | |
288 | }, | |
289 | room_version=RoomVersions.V6, | |
290 | ) | |
291 | ||
292 | # After MSC2174, redaction events keep the redacts content key. | |
293 | self.run_test( | |
294 | {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}}, | |
295 | { | |
296 | "type": "m.room.redaction", | |
297 | "content": {"redacts": "$test2:domain"}, | |
298 | "signatures": {}, | |
299 | "unsigned": {}, | |
300 | }, | |
301 | room_version=RoomVersions.MSC2176, | |
159 | 302 | ) |
160 | 303 | |
161 | 304 |
117 | 117 | |
118 | 118 | def _mock_request(): |
119 | 119 | """Returns a mock which will stand in as a SynapseRequest""" |
120 | return Mock(spec=["getClientIP", "get_user_agent"]) | |
120 | return Mock(spec=["getClientIP", "getHeader"]) |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import json |
15 | import re | |
16 | from typing import Dict | |
17 | from urllib.parse import parse_qs, urlencode, urlparse | |
15 | from typing import Optional | |
16 | from urllib.parse import parse_qs, urlparse | |
18 | 17 | |
19 | 18 | from mock import ANY, Mock, patch |
20 | 19 | |
21 | 20 | import pymacaroons |
22 | 21 | |
23 | from twisted.web.resource import Resource | |
24 | ||
25 | from synapse.api.errors import RedirectException | |
26 | from synapse.handlers.oidc_handler import OidcError | |
27 | 22 | from synapse.handlers.sso import MappingException |
28 | from synapse.rest.client.v1 import login | |
29 | from synapse.rest.synapse.client.pick_username import pick_username_resource | |
30 | 23 | from synapse.server import HomeServer |
31 | 24 | from synapse.types import UserID |
32 | 25 | |
33 | 26 | from tests.test_utils import FakeResponse, simple_async_mock |
34 | 27 | from tests.unittest import HomeserverTestCase, override_config |
28 | ||
29 | try: | |
30 | import authlib # noqa: F401 | |
31 | ||
32 | HAS_OIDC = True | |
33 | except ImportError: | |
34 | HAS_OIDC = False | |
35 | ||
35 | 36 | |
36 | 37 | # These are a few constants that are used as config parameters in the tests. |
37 | 38 | ISSUER = "https://issuer/" |
112 | 113 | |
113 | 114 | |
114 | 115 | class OidcHandlerTestCase(HomeserverTestCase): |
116 | if not HAS_OIDC: | |
117 | skip = "requires OIDC" | |
118 | ||
115 | 119 | def default_config(self): |
116 | 120 | config = super().default_config() |
117 | 121 | config["public_baseurl"] = BASE_URL |
140 | 144 | hs = self.setup_test_homeserver(proxied_http_client=self.http_client) |
141 | 145 | |
142 | 146 | self.handler = hs.get_oidc_handler() |
147 | self.provider = self.handler._providers["oidc"] | |
143 | 148 | sso_handler = hs.get_sso_handler() |
144 | 149 | # Mock the render error method. |
145 | 150 | self.render_error = Mock(return_value=None) |
151 | 156 | return hs |
152 | 157 | |
153 | 158 | def metadata_edit(self, values): |
154 | return patch.dict(self.handler._provider_metadata, values) | |
159 | return patch.dict(self.provider._provider_metadata, values) | |
155 | 160 | |
156 | 161 | def assertRenderedError(self, error, error_description=None): |
162 | self.render_error.assert_called_once() | |
157 | 163 | args = self.render_error.call_args[0] |
158 | 164 | self.assertEqual(args[1], error) |
159 | 165 | if error_description is not None: |
164 | 170 | |
165 | 171 | def test_config(self): |
166 | 172 | """Basic config correctly sets up the callback URL and client auth correctly.""" |
167 | self.assertEqual(self.handler._callback_url, CALLBACK_URL) | |
168 | self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) | |
169 | self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) | |
173 | self.assertEqual(self.provider._callback_url, CALLBACK_URL) | |
174 | self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) | |
175 | self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) | |
170 | 176 | |
171 | 177 | @override_config({"oidc_config": {"discover": True}}) |
172 | 178 | def test_discovery(self): |
173 | 179 | """The handler should discover the endpoints from OIDC discovery document.""" |
174 | 180 | # This would throw if some metadata were invalid |
175 | metadata = self.get_success(self.handler.load_metadata()) | |
181 | metadata = self.get_success(self.provider.load_metadata()) | |
176 | 182 | self.http_client.get_json.assert_called_once_with(WELL_KNOWN) |
177 | 183 | |
178 | 184 | self.assertEqual(metadata.issuer, ISSUER) |
184 | 190 | |
185 | 191 | # subsequent calls should be cached |
186 | 192 | self.http_client.reset_mock() |
187 | self.get_success(self.handler.load_metadata()) | |
193 | self.get_success(self.provider.load_metadata()) | |
188 | 194 | self.http_client.get_json.assert_not_called() |
189 | 195 | |
190 | 196 | @override_config({"oidc_config": COMMON_CONFIG}) |
191 | 197 | def test_no_discovery(self): |
192 | 198 | """When discovery is disabled, it should not try to load from discovery document.""" |
193 | self.get_success(self.handler.load_metadata()) | |
199 | self.get_success(self.provider.load_metadata()) | |
194 | 200 | self.http_client.get_json.assert_not_called() |
195 | 201 | |
196 | 202 | @override_config({"oidc_config": COMMON_CONFIG}) |
197 | 203 | def test_load_jwks(self): |
198 | 204 | """JWKS loading is done once (then cached) if used.""" |
199 | jwks = self.get_success(self.handler.load_jwks()) | |
205 | jwks = self.get_success(self.provider.load_jwks()) | |
200 | 206 | self.http_client.get_json.assert_called_once_with(JWKS_URI) |
201 | 207 | self.assertEqual(jwks, {"keys": []}) |
202 | 208 | |
203 | 209 | # subsequent calls should be cached… |
204 | 210 | self.http_client.reset_mock() |
205 | self.get_success(self.handler.load_jwks()) | |
211 | self.get_success(self.provider.load_jwks()) | |
206 | 212 | self.http_client.get_json.assert_not_called() |
207 | 213 | |
208 | 214 | # …unless forced |
209 | 215 | self.http_client.reset_mock() |
210 | self.get_success(self.handler.load_jwks(force=True)) | |
216 | self.get_success(self.provider.load_jwks(force=True)) | |
211 | 217 | self.http_client.get_json.assert_called_once_with(JWKS_URI) |
212 | 218 | |
213 | 219 | # Throw if the JWKS uri is missing |
214 | 220 | with self.metadata_edit({"jwks_uri": None}): |
215 | self.get_failure(self.handler.load_jwks(force=True), RuntimeError) | |
221 | self.get_failure(self.provider.load_jwks(force=True), RuntimeError) | |
216 | 222 | |
217 | 223 | # Return empty key set if JWKS are not used |
218 | self.handler._scopes = [] # not asking the openid scope | |
224 | self.provider._scopes = [] # not asking the openid scope | |
219 | 225 | self.http_client.get_json.reset_mock() |
220 | jwks = self.get_success(self.handler.load_jwks(force=True)) | |
226 | jwks = self.get_success(self.provider.load_jwks(force=True)) | |
221 | 227 | self.http_client.get_json.assert_not_called() |
222 | 228 | self.assertEqual(jwks, {"keys": []}) |
223 | 229 | |
224 | 230 | @override_config({"oidc_config": COMMON_CONFIG}) |
225 | 231 | def test_validate_config(self): |
226 | 232 | """Provider metadatas are extensively validated.""" |
227 | h = self.handler | |
233 | h = self.provider | |
228 | 234 | |
229 | 235 | # Default test config does not throw |
230 | 236 | h._validate_metadata() |
303 | 309 | """Provider metadata validation can be disabled by config.""" |
304 | 310 | with self.metadata_edit({"issuer": "http://insecure"}): |
305 | 311 | # This should not throw |
306 | self.handler._validate_metadata() | |
312 | self.provider._validate_metadata() | |
307 | 313 | |
308 | 314 | def test_redirect_request(self): |
309 | 315 | """The redirect request has the right arguments & generates a valid session cookie.""" |
310 | 316 | req = Mock(spec=["addCookie"]) |
311 | 317 | url = self.get_success( |
312 | self.handler.handle_redirect_request(req, b"http://client/redirect") | |
318 | self.provider.handle_redirect_request(req, b"http://client/redirect") | |
313 | 319 | ) |
314 | 320 | url = urlparse(url) |
315 | 321 | auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) |
338 | 344 | cookie = args[1] |
339 | 345 | |
340 | 346 | macaroon = pymacaroons.Macaroon.deserialize(cookie) |
341 | state = self.handler._get_value_from_macaroon(macaroon, "state") | |
342 | nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") | |
343 | redirect = self.handler._get_value_from_macaroon( | |
347 | state = self.handler._token_generator._get_value_from_macaroon( | |
348 | macaroon, "state" | |
349 | ) | |
350 | nonce = self.handler._token_generator._get_value_from_macaroon( | |
351 | macaroon, "nonce" | |
352 | ) | |
353 | redirect = self.handler._token_generator._get_value_from_macaroon( | |
344 | 354 | macaroon, "client_redirect_url" |
345 | 355 | ) |
346 | 356 | |
373 | 383 | |
374 | 384 | # ensure that we are correctly testing the fallback when "get_extra_attributes" |
375 | 385 | # is not implemented. |
376 | mapping_provider = self.handler._user_mapping_provider | |
386 | mapping_provider = self.provider._user_mapping_provider | |
377 | 387 | with self.assertRaises(AttributeError): |
378 | 388 | _ = mapping_provider.get_extra_attributes |
379 | 389 | |
388 | 398 | "username": username, |
389 | 399 | } |
390 | 400 | expected_user_id = "@%s:%s" % (username, self.hs.hostname) |
391 | self.handler._exchange_code = simple_async_mock(return_value=token) | |
392 | self.handler._parse_id_token = simple_async_mock(return_value=userinfo) | |
393 | self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) | |
401 | self.provider._exchange_code = simple_async_mock(return_value=token) | |
402 | self.provider._parse_id_token = simple_async_mock(return_value=userinfo) | |
403 | self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) | |
394 | 404 | auth_handler = self.hs.get_auth_handler() |
395 | 405 | auth_handler.complete_sso_login = simple_async_mock() |
396 | 406 | |
400 | 410 | client_redirect_url = "http://client/redirect" |
401 | 411 | user_agent = "Browser" |
402 | 412 | ip_address = "10.0.0.1" |
403 | session = self.handler._generate_oidc_session_token( | |
404 | state=state, | |
405 | nonce=nonce, | |
406 | client_redirect_url=client_redirect_url, | |
407 | ui_auth_session_id=None, | |
408 | ) | |
413 | session = self._generate_oidc_session_token(state, nonce, client_redirect_url) | |
409 | 414 | request = _build_callback_request( |
410 | 415 | code, state, session, user_agent=user_agent, ip_address=ip_address |
411 | 416 | ) |
415 | 420 | auth_handler.complete_sso_login.assert_called_once_with( |
416 | 421 | expected_user_id, request, client_redirect_url, None, |
417 | 422 | ) |
418 | self.handler._exchange_code.assert_called_once_with(code) | |
419 | self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) | |
420 | self.handler._fetch_userinfo.assert_not_called() | |
423 | self.provider._exchange_code.assert_called_once_with(code) | |
424 | self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) | |
425 | self.provider._fetch_userinfo.assert_not_called() | |
421 | 426 | self.render_error.assert_not_called() |
422 | 427 | |
423 | 428 | # Handle mapping errors |
424 | 429 | with patch.object( |
425 | self.handler, | |
430 | self.provider, | |
426 | 431 | "_remote_id_from_userinfo", |
427 | 432 | new=Mock(side_effect=MappingException()), |
428 | 433 | ): |
430 | 435 | self.assertRenderedError("mapping_error") |
431 | 436 | |
432 | 437 | # Handle ID token errors |
433 | self.handler._parse_id_token = simple_async_mock(raises=Exception()) | |
438 | self.provider._parse_id_token = simple_async_mock(raises=Exception()) | |
434 | 439 | self.get_success(self.handler.handle_oidc_callback(request)) |
435 | 440 | self.assertRenderedError("invalid_token") |
436 | 441 | |
437 | 442 | auth_handler.complete_sso_login.reset_mock() |
438 | self.handler._exchange_code.reset_mock() | |
439 | self.handler._parse_id_token.reset_mock() | |
440 | self.handler._fetch_userinfo.reset_mock() | |
443 | self.provider._exchange_code.reset_mock() | |
444 | self.provider._parse_id_token.reset_mock() | |
445 | self.provider._fetch_userinfo.reset_mock() | |
441 | 446 | |
442 | 447 | # With userinfo fetching |
443 | self.handler._scopes = [] # do not ask for the "openid" scope | 
448 | self.provider._scopes = [] # do not ask for the "openid" scope | 
444 | 449 | self.get_success(self.handler.handle_oidc_callback(request)) |
445 | 450 | |
446 | 451 | auth_handler.complete_sso_login.assert_called_once_with( |
447 | 452 | expected_user_id, request, client_redirect_url, None, |
448 | 453 | ) |
449 | self.handler._exchange_code.assert_called_once_with(code) | |
450 | self.handler._parse_id_token.assert_not_called() | |
451 | self.handler._fetch_userinfo.assert_called_once_with(token) | |
454 | self.provider._exchange_code.assert_called_once_with(code) | |
455 | self.provider._parse_id_token.assert_not_called() | |
456 | self.provider._fetch_userinfo.assert_called_once_with(token) | |
452 | 457 | self.render_error.assert_not_called() |
453 | 458 | |
454 | 459 | # Handle userinfo fetching error |
455 | self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) | |
460 | self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) | |
456 | 461 | self.get_success(self.handler.handle_oidc_callback(request)) |
457 | 462 | self.assertRenderedError("fetch_error") |
458 | 463 | |
459 | 464 | # Handle code exchange failure |
460 | self.handler._exchange_code = simple_async_mock( | |
465 | from synapse.handlers.oidc_handler import OidcError | |
466 | ||
467 | self.provider._exchange_code = simple_async_mock( | |
461 | 468 | raises=OidcError("invalid_request") |
462 | 469 | ) |
463 | 470 | self.get_success(self.handler.handle_oidc_callback(request)) |
487 | 494 | self.assertRenderedError("invalid_session") |
488 | 495 | |
489 | 496 | # Mismatching session |
490 | session = self.handler._generate_oidc_session_token( | |
491 | state="state", | |
492 | nonce="nonce", | |
493 | client_redirect_url="http://client/redirect", | |
494 | ui_auth_session_id=None, | |
497 | session = self._generate_oidc_session_token( | |
498 | state="state", nonce="nonce", client_redirect_url="http://client/redirect", | |
495 | 499 | ) |
496 | 500 | request.args = {} |
497 | 501 | request.args[b"state"] = [b"mismatching state"] |
515 | 519 | return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) |
516 | 520 | ) |
517 | 521 | code = "code" |
518 | ret = self.get_success(self.handler._exchange_code(code)) | |
522 | ret = self.get_success(self.provider._exchange_code(code)) | |
519 | 523 | kwargs = self.http_client.request.call_args[1] |
520 | 524 | |
521 | 525 | self.assertEqual(ret, token) |
537 | 541 | body=b'{"error": "foo", "error_description": "bar"}', |
538 | 542 | ) |
539 | 543 | ) |
540 | exc = self.get_failure(self.handler._exchange_code(code), OidcError) | |
544 | from synapse.handlers.oidc_handler import OidcError | |
545 | ||
546 | exc = self.get_failure(self.provider._exchange_code(code), OidcError) | |
541 | 547 | self.assertEqual(exc.value.error, "foo") |
542 | 548 | self.assertEqual(exc.value.error_description, "bar") |
543 | 549 | |
547 | 553 | code=500, phrase=b"Internal Server Error", body=b"Not JSON", |
548 | 554 | ) |
549 | 555 | ) |
550 | exc = self.get_failure(self.handler._exchange_code(code), OidcError) | |
556 | exc = self.get_failure(self.provider._exchange_code(code), OidcError) | |
551 | 557 | self.assertEqual(exc.value.error, "server_error") |
552 | 558 | |
553 | 559 | # Internal server error with JSON body |
559 | 565 | ) |
560 | 566 | ) |
561 | 567 | |
562 | exc = self.get_failure(self.handler._exchange_code(code), OidcError) | |
568 | exc = self.get_failure(self.provider._exchange_code(code), OidcError) | |
563 | 569 | self.assertEqual(exc.value.error, "internal_server_error") |
564 | 570 | |
565 | 571 | # 4xx error without "error" field |
566 | 572 | self.http_client.request = simple_async_mock( |
567 | 573 | return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) |
568 | 574 | ) |
569 | exc = self.get_failure(self.handler._exchange_code(code), OidcError) | |
575 | exc = self.get_failure(self.provider._exchange_code(code), OidcError) | |
570 | 576 | self.assertEqual(exc.value.error, "server_error") |
571 | 577 | |
572 | 578 | # 2xx error with "error" field |
575 | 581 | code=200, phrase=b"OK", body=b'{"error": "some_error"}', |
576 | 582 | ) |
577 | 583 | ) |
578 | exc = self.get_failure(self.handler._exchange_code(code), OidcError) | |
584 | exc = self.get_failure(self.provider._exchange_code(code), OidcError) | |
579 | 585 | self.assertEqual(exc.value.error, "some_error") |
580 | 586 | |
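Each of the error cases above stubs `self.http_client.request` with a `FakeResponse`. That helper's definition is not part of this hunk; a hypothetical minimal stand-in, just to make the shape of the stub clear, could look like:

```python
class FakeResponse:
    """Hypothetical stand-in sketching the fields the assertions above
    rely on; the real helper lives in the shared test utilities and may
    carry more."""

    def __init__(self, code: int, phrase: bytes, body: bytes):
        self.code = code      # HTTP status code, e.g. 400
        self.phrase = phrase  # status phrase, e.g. b"Bad request"
        self.body = body      # raw response body as bytes
```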
581 | 587 | @override_config( |
601 | 607 | "username": "foo", |
602 | 608 | "phone": "1234567", |
603 | 609 | } |
604 | self.handler._exchange_code = simple_async_mock(return_value=token) | |
605 | self.handler._parse_id_token = simple_async_mock(return_value=userinfo) | |
610 | self.provider._exchange_code = simple_async_mock(return_value=token) | |
611 | self.provider._parse_id_token = simple_async_mock(return_value=userinfo) | |
606 | 612 | auth_handler = self.hs.get_auth_handler() |
607 | 613 | auth_handler.complete_sso_login = simple_async_mock() |
608 | 614 | |
609 | 615 | state = "state" |
610 | 616 | client_redirect_url = "http://client/redirect" |
611 | session = self.handler._generate_oidc_session_token( | |
612 | state=state, | |
613 | nonce="nonce", | |
614 | client_redirect_url=client_redirect_url, | |
615 | ui_auth_session_id=None, | |
617 | session = self._generate_oidc_session_token( | |
618 | state=state, nonce="nonce", client_redirect_url=client_redirect_url, | |
616 | 619 | ) |
617 | 620 | request = _build_callback_request("code", state, session) |
618 | 621 | |
826 | 829 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
827 | 830 | self.assertRenderedError("mapping_error", "localpart is invalid: ") |
828 | 831 | |
829 | ||
830 | class UsernamePickerTestCase(HomeserverTestCase): | |
831 | servlets = [login.register_servlets] | |
832 | ||
833 | def default_config(self): | |
834 | config = super().default_config() | |
835 | config["public_baseurl"] = BASE_URL | |
836 | oidc_config = { | |
837 | "enabled": True, | |
838 | "client_id": CLIENT_ID, | |
839 | "client_secret": CLIENT_SECRET, | |
840 | "issuer": ISSUER, | |
841 | "scopes": SCOPES, | |
842 | "user_mapping_provider": { | |
843 | "config": {"display_name_template": "{{ user.displayname }}"} | |
844 | }, | |
845 | } | |
846 | ||
847 | # Update this config with what's in the default config so that | |
848 | # override_config works as expected. | |
849 | oidc_config.update(config.get("oidc_config", {})) | |
850 | config["oidc_config"] = oidc_config | |
851 | ||
852 | # whitelist this client URI so we redirect straight to it rather than | |
853 | # serving a confirmation page | |
854 | config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} | |
855 | return config | |
856 | ||
857 | def create_resource_dict(self) -> Dict[str, Resource]: | |
858 | d = super().create_resource_dict() | |
859 | d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) | |
860 | return d | |
861 | ||
862 | def test_username_picker(self): | |
863 | """Test the happy path of a username picker flow.""" | |
864 | client_redirect_url = "https://whitelisted.client" | |
865 | ||
866 | # first of all, mock up an OIDC callback to the OidcHandler, which should | |
867 | # raise a RedirectException | |
868 | userinfo = {"sub": "tester", "displayname": "Jonny"} | |
869 | f = self.get_failure( | |
870 | _make_callback_with_userinfo( | |
871 | self.hs, userinfo, client_redirect_url=client_redirect_url | |
832 | def _generate_oidc_session_token( | |
833 | self, | |
834 | state: str, | |
835 | nonce: str, | |
836 | client_redirect_url: str, | |
837 | ui_auth_session_id: Optional[str] = None, | |
838 | ) -> str: | |
839 | from synapse.handlers.oidc_handler import OidcSessionData | |
840 | ||
841 | return self.handler._token_generator.generate_oidc_session_token( | |
842 | state=state, | |
843 | session_data=OidcSessionData( | |
844 | idp_id="oidc", | |
845 | nonce=nonce, | |
846 | client_redirect_url=client_redirect_url, | |
847 | ui_auth_session_id=ui_auth_session_id, | |
872 | 848 | ), |
873 | RedirectException, | |
874 | ) | |
875 | ||
876 | # check the Location and cookies returned by the RedirectException | |
877 | self.assertEqual(f.value.location, b"/_synapse/client/pick_username") | |
878 | cookieheader = f.value.cookies[0] | |
879 | regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") | |
880 | m = regex.search(cookieheader) | |
881 | if not m: | |
882 | self.fail("cookie header %s does not match %s" % (cookieheader, regex)) | |
883 | ||
884 | # introspect the sso handler a bit to check that the username mapping session | |
885 | # looks ok. | |
886 | session_id = m.group(1).decode("ascii") | |
887 | username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions | |
888 | self.assertIn( | |
889 | session_id, username_mapping_sessions, "session id not found in map" | |
890 | ) | |
891 | session = username_mapping_sessions[session_id] | |
892 | self.assertEqual(session.remote_user_id, "tester") | |
893 | self.assertEqual(session.display_name, "Jonny") | |
894 | self.assertEqual(session.client_redirect_url, client_redirect_url) | |
895 | ||
896 | # the expiry time should be about 15 minutes away | |
897 | expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) | |
898 | self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) | |
899 | ||
900 | # Now, submit a username to the username picker, which should serve a redirect | |
901 | # back to the client | |
902 | submit_path = f.value.location + b"/submit" | |
903 | content = urlencode({b"username": b"bobby"}).encode("utf8") | |
904 | chan = self.make_request( | |
905 | "POST", | |
906 | path=submit_path, | |
907 | content=content, | |
908 | content_is_form=True, | |
909 | custom_headers=[ | |
910 | ("Cookie", cookieheader), | |
911 | # old versions of twisted don't do form-parsing without a valid | |
912 | # content-length header. | |
913 | ("Content-Length", str(len(content))), | |
914 | ], | |
915 | ) | |
916 | self.assertEqual(chan.code, 302, chan.result) | |
917 | location_headers = chan.headers.getRawHeaders("Location") | |
918 | # ensure that the returned location starts with the requested redirect URL | |
919 | self.assertEqual( | |
920 | location_headers[0][: len(client_redirect_url)], client_redirect_url | |
921 | ) | |
922 | ||
923 | # fish the login token out of the returned redirect uri | |
924 | parts = urlparse(location_headers[0]) | |
925 | query = parse_qs(parts.query) | |
926 | login_token = query["loginToken"][0] | |
927 | ||
928 | # finally, submit the matrix login token to the login API, which gives us our | |
929 | # matrix access token, mxid, and device id. | |
930 | chan = self.make_request( | |
931 | "POST", "/login", content={"type": "m.login.token", "token": login_token}, | |
932 | ) | |
933 | self.assertEqual(chan.code, 200, chan.result) | |
934 | self.assertEqual(chan.json_body["user_id"], "@bobby:test") | |
849 | ) | |
935 | 850 | |
936 | 851 | |
937 | 852 | async def _make_callback_with_userinfo( |
947 | 862 | userinfo: the OIDC userinfo dict |
948 | 863 | client_redirect_url: the URL to redirect to on success. |
949 | 864 | """ |
865 | from synapse.handlers.oidc_handler import OidcSessionData | |
866 | ||
950 | 867 | handler = hs.get_oidc_handler() |
951 | handler._exchange_code = simple_async_mock(return_value={}) | |
952 | handler._parse_id_token = simple_async_mock(return_value=userinfo) | |
953 | handler._fetch_userinfo = simple_async_mock(return_value=userinfo) | |
868 | provider = handler._providers["oidc"] | |
869 | provider._exchange_code = simple_async_mock(return_value={}) | |
870 | provider._parse_id_token = simple_async_mock(return_value=userinfo) | |
871 | provider._fetch_userinfo = simple_async_mock(return_value=userinfo) | |
954 | 872 | |
955 | 873 | state = "state" |
956 | session = handler._generate_oidc_session_token( | |
874 | session = handler._token_generator.generate_oidc_session_token( | |
957 | 875 | state=state, |
958 | nonce="nonce", | |
959 | client_redirect_url=client_redirect_url, | |
960 | ui_auth_session_id=None, | |
876 | session_data=OidcSessionData( | |
877 | idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, | |
878 | ), | |
961 | 879 | ) |
962 | 880 | request = _build_callback_request("code", state, session) |
963 | 881 | |
993 | 911 | "addCookie", |
994 | 912 | "requestHeaders", |
995 | 913 | "getClientIP", |
996 | "get_user_agent", | |
914 | "getHeader", | |
997 | 915 | ] |
998 | 916 | ) |
999 | 917 | |
1002 | 920 | request.args[b"code"] = [code.encode("utf-8")] |
1003 | 921 | request.args[b"state"] = [state.encode("utf-8")] |
1004 | 922 | request.getClientIP.return_value = ip_address |
1005 | request.get_user_agent.return_value = user_agent | |
1006 | 923 | return request |
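Because these request doubles are built with `Mock(spec=[...])`, only the listed attributes exist on them, so swapping `get_user_agent` for `getHeader` here keeps the mocks honest about the request API the handler now uses. A small illustration of that behaviour:

```python
from mock import Mock

request = Mock(spec=["getClientIP", "getHeader"])
request.getHeader.return_value = "Browser"  # allowed: "getHeader" is in the spec

try:
    request.get_user_agent  # no longer in the spec list
except AttributeError:
    print("get_user_agent is not part of the mocked request API")
```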
104 | 104 | "Frank", |
105 | 105 | ) |
106 | 106 | |
107 | # Set displayname to an empty string | |
108 | yield defer.ensureDeferred( | |
109 | self.handler.set_displayname( | |
110 | self.frank, synapse.types.create_requester(self.frank), "" | |
111 | ) | |
112 | ) | |
113 | ||
114 | self.assertIsNone( | |
115 | ( | |
116 | yield defer.ensureDeferred( | |
117 | self.store.get_profile_displayname(self.frank.localpart) | |
118 | ) | |
119 | ) | |
120 | ) | |
121 | ||
107 | 122 | @defer.inlineCallbacks |
108 | 123 | def test_set_my_name_if_disabled(self): |
109 | 124 | self.hs.config.enable_set_displayname = False |
222 | 237 | "http://my.server/me.png", |
223 | 238 | ) |
224 | 239 | |
240 | # Set avatar to an empty string | |
241 | yield defer.ensureDeferred( | |
242 | self.handler.set_avatar_url( | |
243 | self.frank, synapse.types.create_requester(self.frank), "", | |
244 | ) | |
245 | ) | |
246 | ||
247 | self.assertIsNone( | |
248 | ( | |
249 | yield defer.ensureDeferred( | |
250 | self.store.get_profile_avatar_url(self.frank.localpart) | |
251 | ) | |
252 | ), | |
253 | ) | |
254 | ||
225 | 255 | @defer.inlineCallbacks |
226 | 256 | def test_set_my_avatar_if_disabled(self): |
227 | 257 | self.hs.config.enable_set_avatar_url = False |
261 | 261 | |
262 | 262 | def _mock_request(): |
263 | 263 | """Returns a mock which will stand in as a SynapseRequest""" |
264 | return Mock(spec=["getClientIP", "get_user_agent"]) | |
264 | return Mock(spec=["getClientIP", "getHeader"]) |
1094 | 1094 | # Expire both caches and repeat the request |
1095 | 1095 | self.reactor.pump((10000.0,)) |
1096 | 1096 | |
1097 | # Repated the request, this time it should fail if the lookup fails. | |
1097 | # Repeat the request, this time it should fail if the lookup fails. | |
1098 | 1098 | fetch_d = defer.ensureDeferred( |
1099 | 1099 | self.well_known_resolver.get_well_known(b"testserv") |
1100 | 1100 | ) |
1129 | 1129 | content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', |
1130 | 1130 | ) |
1131 | 1131 | |
1132 | # The result is sucessful, but disabled delegation. | |
1132 | # The result is successful, but disabled delegation. | |
1133 | 1133 | r = self.successResultOf(fetch_d) |
1134 | 1134 | self.assertIsNone(r.delegated_server) |
1135 | 1135 |
0 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | from io import BytesIO | |
15 | ||
16 | from mock import Mock | |
17 | ||
18 | from twisted.python.failure import Failure | |
19 | from twisted.web.client import ResponseDone | |
20 | ||
21 | from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size | |
22 | ||
23 | from tests.unittest import TestCase | |
24 | ||
25 | ||
26 | class ReadBodyWithMaxSizeTests(TestCase): | |
27 | def setUp(self): | |
28 | """Start reading the body, returns the response, result and proto""" | |
29 | self.response = Mock() | |
30 | self.result = BytesIO() | |
31 | self.deferred = read_body_with_max_size(self.response, self.result, 6) | |
32 | ||
33 | # Fish the protocol out of the response. | |
34 | self.protocol = self.response.deliverBody.call_args[0][0] | |
35 | self.protocol.transport = Mock() | |
36 | ||
37 | def _cleanup_error(self): | |
38 | """Ensure that the error in the Deferred is handled gracefully.""" | |
39 | called = [False] | |
40 | ||
41 | def errback(f): | |
42 | called[0] = True | |
43 | ||
44 | self.deferred.addErrback(errback) | |
45 | self.assertTrue(called[0]) | |
46 | ||
47 | def test_no_error(self): | |
48 | """A response that is NOT too large.""" | |
49 | ||
50 | # Start sending data. | |
51 | self.protocol.dataReceived(b"12345") | |
52 | # Close the connection. | |
53 | self.protocol.connectionLost(Failure(ResponseDone())) | |
54 | ||
55 | self.assertEqual(self.result.getvalue(), b"12345") | |
56 | self.assertEqual(self.deferred.result, 5) | |
57 | ||
58 | def test_too_large(self): | |
59 | """A response which is too large raises an exception.""" | |
60 | ||
61 | # Start sending data. | |
62 | self.protocol.dataReceived(b"1234567890") | |
63 | # Close the connection. | |
64 | self.protocol.connectionLost(Failure(ResponseDone())) | |
65 | ||
66 | self.assertEqual(self.result.getvalue(), b"1234567890") | |
67 | self.assertIsInstance(self.deferred.result, Failure) | |
68 | self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) | |
69 | self._cleanup_error() | |
70 | ||
71 | def test_multiple_packets(self): | |
72 | """Data should be accummulated through mutliple packets.""" | |
73 | ||
74 | # Start sending data. | |
75 | self.protocol.dataReceived(b"12") | |
76 | self.protocol.dataReceived(b"34") | |
77 | # Close the connection. | |
78 | self.protocol.connectionLost(Failure(ResponseDone())) | |
79 | ||
80 | self.assertEqual(self.result.getvalue(), b"1234") | |
81 | self.assertEqual(self.deferred.result, 4) | |
82 | ||
83 | def test_additional_data(self): | |
84 | """A connection can receive data after being closed.""" | |
85 | ||
86 | # Start sending data. | |
87 | self.protocol.dataReceived(b"1234567890") | |
88 | self.assertIsInstance(self.deferred.result, Failure) | |
89 | self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) | |
90 | self.protocol.transport.loseConnection.assert_called_once() | |
91 | ||
92 | # More data might have come in. | |
93 | self.protocol.dataReceived(b"1234567890") | |
94 | # Close the connection. | |
95 | self.protocol.connectionLost(Failure(ResponseDone())) | |
96 | ||
97 | self.assertEqual(self.result.getvalue(), b"1234567890") | |
98 | self.assertIsInstance(self.deferred.result, Failure) | |
99 | self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) | |
100 | self._cleanup_error() |
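For context, outside of these protocol-level tests `read_body_with_max_size` is consumed as a Deferred. A rough usage sketch, assuming `response` is a Twisted `IResponse` obtained from an agent:

```python
from io import BytesIO

from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size


async def read_limited_body(response, max_size: int) -> bytes:
    """Sketch: stream a response body into memory, refusing anything
    larger than max_size bytes (assumes `response` is a twisted.web
    IResponse)."""
    buf = BytesIO()
    try:
        # fires with the number of bytes read, or fails with
        # BodyExceededMaxSize once the limit is crossed
        await read_body_with_max_size(response, buf, max_size)
    except BodyExceededMaxSize:
        raise ValueError("response body exceeded %d bytes" % (max_size,))
    return buf.getvalue()
```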
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name | |
14 | from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name | |
15 | 15 | |
16 | 16 | from tests import unittest |
17 | 17 |
559 | 559 | self.pump() |
560 | 560 | |
561 | 561 | f = self.failureResultOf(test_d) |
562 | self.assertIsInstance(f.value, ValueError) | |
562 | self.assertIsInstance(f.value, RequestSendFailed) |
57 | 57 | ] |
58 | 58 | |
59 | 59 | def prepare(self, reactor, clock, hs): |
60 | self.store = hs.get_datastore() | |
61 | ||
62 | 60 | self.admin_user = self.register_user("admin", "pass", admin=True) |
63 | 61 | self.admin_user_tok = self.login("admin", "pass") |
64 | 62 | |
154 | 152 | ] |
155 | 153 | |
156 | 154 | def prepare(self, reactor, clock, hs): |
157 | self.store = hs.get_datastore() | |
158 | self.hs = hs | |
159 | ||
160 | 155 | # Allow for uploading and downloading to/from the media repo |
161 | 156 | self.media_repo = hs.get_media_repository_resource() |
162 | 157 | self.download_resource = self.media_repo.children[b"download"] |
430 | 425 | |
431 | 426 | # Mark the second item as safe from quarantine. |
432 | 427 | _, media_id_2 = server_and_media_id_2.split("/") |
433 | self.get_success(self.store.mark_local_media_as_safe(media_id_2)) | |
428 | # Protect the media from quarantine via the admin API | 
429 | url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) | |
430 | channel = self.make_request("POST", url, access_token=admin_user_tok) | |
431 | self.pump(1.0) | |
432 | self.assertEqual(200, int(channel.code), msg=channel.result["body"]) | |
434 | 433 | |
435 | 434 | # Quarantine all media by this user |
436 | 435 | url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( |
31 | 31 | ] |
32 | 32 | |
33 | 33 | def prepare(self, reactor, clock, hs): |
34 | self.store = hs.get_datastore() | |
35 | ||
36 | 34 | self.admin_user = self.register_user("admin", "pass", admin=True) |
37 | 35 | self.admin_user_tok = self.login("admin", "pass") |
38 | 36 | |
370 | 368 | ] |
371 | 369 | |
372 | 370 | def prepare(self, reactor, clock, hs): |
373 | self.store = hs.get_datastore() | |
374 | ||
375 | 371 | self.admin_user = self.register_user("admin", "pass", admin=True) |
376 | 372 | self.admin_user_tok = self.login("admin", "pass") |
377 | 373 |
34 | 34 | ] |
35 | 35 | |
36 | 36 | def prepare(self, reactor, clock, hs): |
37 | self.handler = hs.get_device_handler() | |
38 | 37 | self.media_repo = hs.get_media_repository_resource() |
39 | 38 | self.server_name = hs.hostname |
40 | 39 | |
180 | 179 | ] |
181 | 180 | |
182 | 181 | def prepare(self, reactor, clock, hs): |
183 | self.handler = hs.get_device_handler() | |
184 | 182 | self.media_repo = hs.get_media_repository_resource() |
185 | 183 | self.server_name = hs.hostname |
186 | 184 |
604 | 604 | ] |
605 | 605 | |
606 | 606 | def prepare(self, reactor, clock, hs): |
607 | self.store = hs.get_datastore() | |
608 | ||
609 | 607 | # Create user |
610 | 608 | self.admin_user = self.register_user("admin", "pass", admin=True) |
611 | 609 | self.admin_user_tok = self.login("admin", "pass") |
30 | 30 | ] |
31 | 31 | |
32 | 32 | def prepare(self, reactor, clock, hs): |
33 | self.store = hs.get_datastore() | |
34 | 33 | self.media_repo = hs.get_media_repository_resource() |
35 | 34 | |
36 | 35 | self.admin_user = self.register_user("admin", "pass", admin=True) |
24 | 24 | import synapse.rest.admin |
25 | 25 | from synapse.api.constants import UserTypes |
26 | 26 | from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError |
27 | from synapse.api.room_versions import RoomVersions | |
27 | 28 | from synapse.rest.client.v1 import login, logout, profile, room |
28 | 29 | from synapse.rest.client.v2_alpha import devices, sync |
29 | 30 | |
586 | 587 | _search_test(None, "bar", "user_id") |
587 | 588 | |
588 | 589 | |
590 | class DeactivateAccountTestCase(unittest.HomeserverTestCase): | |
591 | ||
592 | servlets = [ | |
593 | synapse.rest.admin.register_servlets, | |
594 | login.register_servlets, | |
595 | ] | |
596 | ||
597 | def prepare(self, reactor, clock, hs): | |
598 | self.store = hs.get_datastore() | |
599 | ||
600 | self.admin_user = self.register_user("admin", "pass", admin=True) | |
601 | self.admin_user_tok = self.login("admin", "pass") | |
602 | ||
603 | self.other_user = self.register_user("user", "pass", displayname="User1") | |
604 | self.other_user_token = self.login("user", "pass") | |
605 | self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( | |
606 | self.other_user | |
607 | ) | |
608 | self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote( | |
609 | self.other_user | |
610 | ) | |
611 | ||
612 | # set attributes for user | |
613 | self.get_success( | |
614 | self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") | |
615 | ) | |
616 | self.get_success( | |
617 | self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) | |
618 | ) | |
619 | ||
620 | def test_no_auth(self): | |
621 | """ | |
622 | Try to deactivate users without authentication. | |
623 | """ | |
624 | channel = self.make_request("POST", self.url, b"{}") | |
625 | ||
626 | self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) | |
627 | self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) | |
628 | ||
629 | def test_requester_is_not_admin(self): | |
630 | """ | |
631 | If the user is not a server admin, an error is returned. | |
632 | """ | |
633 | url = "/_synapse/admin/v1/deactivate/@bob:test" | |
634 | ||
635 | channel = self.make_request("POST", url, access_token=self.other_user_token) | |
636 | ||
637 | self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) | |
638 | self.assertEqual("You are not a server admin", channel.json_body["error"]) | |
639 | ||
640 | channel = self.make_request( | |
641 | "POST", url, access_token=self.other_user_token, content=b"{}", | |
642 | ) | |
643 | ||
644 | self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) | |
645 | self.assertEqual("You are not a server admin", channel.json_body["error"]) | |
646 | ||
647 | def test_user_does_not_exist(self): | |
648 | """ | |
649 | Tests that deactivation for a user that does not exist returns a 404 | |
650 | """ | |
651 | ||
652 | channel = self.make_request( | |
653 | "POST", | |
654 | "/_synapse/admin/v1/deactivate/@unknown_person:test", | |
655 | access_token=self.admin_user_tok, | |
656 | ) | |
657 | ||
658 | self.assertEqual(404, channel.code, msg=channel.json_body) | |
659 | self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) | |
660 | ||
661 | def test_erase_is_not_bool(self): | |
662 | """ | |
663 | If the parameter `erase` is not a boolean, return an error | 
664 | """ | |
665 | body = json.dumps({"erase": "False"}) | |
666 | ||
667 | channel = self.make_request( | |
668 | "POST", | |
669 | self.url, | |
670 | content=body.encode(encoding="utf_8"), | |
671 | access_token=self.admin_user_tok, | |
672 | ) | |
673 | ||
674 | self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |
675 | self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) | |
676 | ||
677 | def test_user_is_not_local(self): | |
678 | """ | |
679 | Tests that deactivation for a user that is not local returns a 400 | 
680 | """ | |
681 | url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" | |
682 | ||
683 | channel = self.make_request("POST", url, access_token=self.admin_user_tok) | |
684 | ||
685 | self.assertEqual(400, channel.code, msg=channel.json_body) | |
686 | self.assertEqual("Can only deactivate local users", channel.json_body["error"]) | |
687 | ||
688 | def test_deactivate_user_erase_true(self): | |
689 | """ | |
690 | Test deactivating a user with `erase` set to `true` | 
691 | """ | |
692 | ||
693 | # Get user | |
694 | channel = self.make_request( | |
695 | "GET", self.url_other_user, access_token=self.admin_user_tok, | |
696 | ) | |
697 | ||
698 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
699 | self.assertEqual("@user:test", channel.json_body["name"]) | |
700 | self.assertEqual(False, channel.json_body["deactivated"]) | |
701 | self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) | |
702 | self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) | |
703 | self.assertEqual("User1", channel.json_body["displayname"]) | |
704 | ||
705 | # Deactivate user | |
706 | body = json.dumps({"erase": True}) | |
707 | ||
708 | channel = self.make_request( | |
709 | "POST", | |
710 | self.url, | |
711 | access_token=self.admin_user_tok, | |
712 | content=body.encode(encoding="utf_8"), | |
713 | ) | |
714 | ||
715 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
716 | ||
717 | # Get user | |
718 | channel = self.make_request( | |
719 | "GET", self.url_other_user, access_token=self.admin_user_tok, | |
720 | ) | |
721 | ||
722 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
723 | self.assertEqual("@user:test", channel.json_body["name"]) | |
724 | self.assertEqual(True, channel.json_body["deactivated"]) | |
725 | self.assertEqual(0, len(channel.json_body["threepids"])) | |
726 | self.assertIsNone(channel.json_body["avatar_url"]) | |
727 | self.assertIsNone(channel.json_body["displayname"]) | |
728 | ||
729 | self._is_erased("@user:test", True) | |
730 | ||
731 | def test_deactivate_user_erase_false(self): | |
732 | """ | |
733 | Test deactivating a user with `erase` set to `false` | 
734 | """ | |
735 | ||
736 | # Get user | |
737 | channel = self.make_request( | |
738 | "GET", self.url_other_user, access_token=self.admin_user_tok, | |
739 | ) | |
740 | ||
741 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
742 | self.assertEqual("@user:test", channel.json_body["name"]) | |
743 | self.assertEqual(False, channel.json_body["deactivated"]) | |
744 | self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) | |
745 | self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) | |
746 | self.assertEqual("User1", channel.json_body["displayname"]) | |
747 | ||
748 | # Deactivate user | |
749 | body = json.dumps({"erase": False}) | |
750 | ||
751 | channel = self.make_request( | |
752 | "POST", | |
753 | self.url, | |
754 | access_token=self.admin_user_tok, | |
755 | content=body.encode(encoding="utf_8"), | |
756 | ) | |
757 | ||
758 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
759 | ||
760 | # Get user | |
761 | channel = self.make_request( | |
762 | "GET", self.url_other_user, access_token=self.admin_user_tok, | |
763 | ) | |
764 | ||
765 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
766 | self.assertEqual("@user:test", channel.json_body["name"]) | |
767 | self.assertEqual(True, channel.json_body["deactivated"]) | |
768 | self.assertEqual(0, len(channel.json_body["threepids"])) | |
769 | self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) | |
770 | self.assertEqual("User1", channel.json_body["displayname"]) | |
771 | ||
772 | self._is_erased("@user:test", False) | |
773 | ||
774 | def _is_erased(self, user_id: str, expect: bool) -> None: | |
775 | """Assert that the user is erased or not | |
776 | """ | |
777 | d = self.store.is_user_erased(user_id) | |
778 | if expect: | |
779 | self.assertTrue(self.get_success(d)) | |
780 | else: | |
781 | self.assertFalse(self.get_success(d)) | |
782 | ||
783 | ||
589 | 784 | class UserRestTestCase(unittest.HomeserverTestCase): |
590 | 785 | |
591 | 786 | servlets = [ |
985 | 1180 | Test deactivating another user. |
986 | 1181 | """ |
987 | 1182 | |
1183 | # set attributes for user | |
1184 | self.get_success( | |
1185 | self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") | |
1186 | ) | |
1187 | self.get_success( | |
1188 | self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) | |
1189 | ) | |
1190 | ||
1191 | # Get user | |
1192 | channel = self.make_request( | |
1193 | "GET", self.url_other_user, access_token=self.admin_user_tok, | |
1194 | ) | |
1195 | ||
1196 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |
1197 | self.assertEqual("@user:test", channel.json_body["name"]) | |
1198 | self.assertEqual(False, channel.json_body["deactivated"]) | |
1199 | self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) | |
1200 | self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) | |
1201 | self.assertEqual("User", channel.json_body["displayname"]) | |
1202 | ||
988 | 1203 | # Deactivate user |
989 | 1204 | body = json.dumps({"deactivated": True}) |
990 | 1205 | |
998 | 1213 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) |
999 | 1214 | self.assertEqual("@user:test", channel.json_body["name"]) |
1000 | 1215 | self.assertEqual(True, channel.json_body["deactivated"]) |
1216 | self.assertEqual(0, len(channel.json_body["threepids"])) | |
1217 | self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) | |
1218 | self.assertEqual("User", channel.json_body["displayname"]) | |
1001 | 1219 | # the user is deactivated; their threepids will be deleted
1002 | 1220 | |
1003 | 1221 | # Get user |
1008 | 1226 | self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) |
1009 | 1227 | self.assertEqual("@user:test", channel.json_body["name"]) |
1010 | 1228 | self.assertEqual(True, channel.json_body["deactivated"]) |
1229 | self.assertEqual(0, len(channel.json_body["threepids"])) | |
1230 | self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) | |
1231 | self.assertEqual("User", channel.json_body["displayname"]) | |
1011 | 1232 | |
1012 | 1233 | @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) |
1013 | 1234 | def test_change_name_deactivate_user_user_directory(self): |
1203 | 1424 | ] |
1204 | 1425 | |
1205 | 1426 | def prepare(self, reactor, clock, hs): |
1206 | self.store = hs.get_datastore() | |
1207 | ||
1208 | 1427 | self.admin_user = self.register_user("admin", "pass", admin=True) |
1209 | 1428 | self.admin_user_tok = self.login("admin", "pass") |
1210 | 1429 | |
1235 | 1454 | |
1236 | 1455 | def test_user_does_not_exist(self): |
1237 | 1456 | """ |
1238 | Tests that a lookup for a user that does not exist returns a 404 | |
1457 | Tests that a lookup for a user that does not exist returns an empty list | |
1239 | 1458 | """ |
1240 | 1459 | url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" |
1241 | 1460 | channel = self.make_request("GET", url, access_token=self.admin_user_tok,) |
1242 | 1461 | |
1243 | self.assertEqual(404, channel.code, msg=channel.json_body) | |
1244 | self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) | |
1462 | self.assertEqual(200, channel.code, msg=channel.json_body) | |
1463 | self.assertEqual(0, channel.json_body["total"]) | |
1464 | self.assertEqual(0, len(channel.json_body["joined_rooms"])) | |
1245 | 1465 | |
1246 | 1466 | def test_user_is_not_local(self): |
1247 | 1467 | """ |
1248 | Tests that a lookup for a user that is not a local returns a 400 | |
1468 | Tests that a lookup for a user who is not local and is in no known rooms returns an empty list | 
1249 | 1469 | """ |
1250 | 1470 | url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" |
1251 | 1471 | |
1252 | 1472 | channel = self.make_request("GET", url, access_token=self.admin_user_tok,) |
1253 | 1473 | |
1254 | self.assertEqual(400, channel.code, msg=channel.json_body) | |
1255 | self.assertEqual("Can only lookup local users", channel.json_body["error"]) | |
1474 | self.assertEqual(200, channel.code, msg=channel.json_body) | |
1475 | self.assertEqual(0, channel.json_body["total"]) | |
1476 | self.assertEqual(0, len(channel.json_body["joined_rooms"])) | |
1256 | 1477 | |
1257 | 1478 | def test_no_memberships(self): |
1258 | 1479 | """ |
1282 | 1503 | self.assertEqual(200, channel.code, msg=channel.json_body) |
1283 | 1504 | self.assertEqual(number_rooms, channel.json_body["total"]) |
1284 | 1505 | self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) |
1506 | ||
1507 | def test_get_rooms_with_nonlocal_user(self): | |
1508 | """ | |
1509 | Tests that a normal lookup for rooms is successful with a non-local user | |
1510 | """ | |
1511 | ||
1512 | other_user_tok = self.login("user", "pass") | |
1513 | event_builder_factory = self.hs.get_event_builder_factory() | |
1514 | event_creation_handler = self.hs.get_event_creation_handler() | |
1515 | storage = self.hs.get_storage() | |
1516 | ||
1517 | # Create two rooms, one with a local user only and one with both a local | |
1518 | # and remote user. | |
1519 | self.helper.create_room_as(self.other_user, tok=other_user_tok) | |
1520 | local_and_remote_room_id = self.helper.create_room_as( | |
1521 | self.other_user, tok=other_user_tok | |
1522 | ) | |
1523 | ||
1524 | # Add a remote user to the room. | |
1525 | builder = event_builder_factory.for_room_version( | |
1526 | RoomVersions.V1, | |
1527 | { | |
1528 | "type": "m.room.member", | |
1529 | "sender": "@joiner:remote_hs", | |
1530 | "state_key": "@joiner:remote_hs", | |
1531 | "room_id": local_and_remote_room_id, | |
1532 | "content": {"membership": "join"}, | |
1533 | }, | |
1534 | ) | |
1535 | ||
1536 | event, context = self.get_success( | |
1537 | event_creation_handler.create_new_client_event(builder) | |
1538 | ) | |
1539 | ||
1540 | self.get_success(storage.persistence.persist_event(event, context)) | |
1541 | ||
1542 | # Now get rooms | |
1543 | url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" | |
1544 | channel = self.make_request("GET", url, access_token=self.admin_user_tok,) | |
1545 | ||
1546 | self.assertEqual(200, channel.code, msg=channel.json_body) | |
1547 | self.assertEqual(1, channel.json_body["total"]) | |
1548 | self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) | |
1285 | 1549 | |
1286 | 1550 | |
1287 | 1551 | class PushersRestTestCase(unittest.HomeserverTestCase): |
1400 | 1664 | ] |
1401 | 1665 | |
1402 | 1666 | def prepare(self, reactor, clock, hs): |
1403 | self.store = hs.get_datastore() | |
1404 | 1667 | self.media_repo = hs.get_media_repository_resource() |
1405 | 1668 | |
1406 | 1669 | self.admin_user = self.register_user("admin", "pass", admin=True) |
1867 | 2130 | ] |
1868 | 2131 | |
1869 | 2132 | def prepare(self, reactor, clock, hs): |
1870 | self.store = hs.get_datastore() | |
1871 | ||
1872 | 2133 | self.admin_user = self.register_user("admin", "pass", admin=True) |
1873 | 2134 | self.admin_user_tok = self.login("admin", "pass") |
1874 | 2135 |
0 | import json | |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2019-2021 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
1 | 15 | import time |
2 | 16 | import urllib.parse |
17 | from typing import Any, Dict, Union | |
18 | from urllib.parse import urlencode | |
3 | 19 | |
4 | 20 | from mock import Mock |
5 | 21 | |
6 | import jwt | |
22 | import pymacaroons | |
23 | ||
24 | from twisted.web.resource import Resource | |
7 | 25 | |
8 | 26 | import synapse.rest.admin |
9 | 27 | from synapse.appservice import ApplicationService |
10 | 28 | from synapse.rest.client.v1 import login, logout |
11 | 29 | from synapse.rest.client.v2_alpha import devices, register |
12 | 30 | from synapse.rest.client.v2_alpha.account import WhoamiRestServlet |
31 | from synapse.rest.synapse.client.pick_idp import PickIdpResource | |
32 | from synapse.rest.synapse.client.pick_username import pick_username_resource | |
33 | from synapse.types import create_requester | |
13 | 34 | |
14 | 35 | from tests import unittest |
15 | from tests.unittest import override_config | |
36 | from tests.handlers.test_oidc import HAS_OIDC | |
37 | from tests.handlers.test_saml import has_saml2 | |
38 | from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG | |
39 | from tests.test_utils.html_parsers import TestHtmlParser | |
40 | from tests.unittest import HomeserverTestCase, override_config, skip_unless | |
41 | ||
42 | try: | |
43 | import jwt | |
44 | ||
45 | HAS_JWT = True | |
46 | except ImportError: | |
47 | HAS_JWT = False | |
48 | ||
49 | ||
50 | # public_base_url used in some tests | |
51 | BASE_URL = "https://synapse/" | |
52 | ||
53 | # CAS server used in some tests | |
54 | CAS_SERVER = "https://fake.test" | |
55 | ||
56 | # just enough to tell pysaml2 where to redirect to | |
57 | SAML_SERVER = "https://test.saml.server/idp/sso" | |
58 | TEST_SAML_METADATA = """ | |
59 | <md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"> | |
60 | <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"> | |
61 | <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/> | |
62 | </md:IDPSSODescriptor> | |
63 | </md:EntityDescriptor> | |
64 | """ % { | |
65 | "SAML_SERVER": SAML_SERVER, | |
66 | } | |
16 | 67 | |
17 | 68 | LOGIN_URL = b"/_matrix/client/r0/login" |
18 | 69 | TEST_URL = b"/_matrix/client/r0/account/whoami" |
70 | ||
71 | # a (valid) url with some annoying characters in it. %3D is =, %26 is &, %2B is + | 
72 | TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"' | |
73 | ||
74 | # the query params in TEST_CLIENT_REDIRECT_URL | |
75 | EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] | |
19 | 76 | |
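A self-contained check of the decoding described in the comment above (the URL here is hand-encoded for the example, not produced by Synapse):

```python
from urllib.parse import parse_qsl, urlsplit

# hand-encoded query equivalent to TEST_CLIENT_REDIRECT_URL's parameters
encoded = "https://x?%3Cab+c%3E=&q%22+%3D%2B%22=%22f%C3%B6%26%3Do%22"
params = parse_qsl(urlsplit(encoded).query, keep_blank_values=True)
assert params == [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
```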
20 | 77 | |
21 | 78 | class LoginRestServletTestCase(unittest.HomeserverTestCase): |
310 | 367 | self.assertEquals(channel.result["code"], b"200", channel.result) |
311 | 368 | |
312 | 369 | |
313 | class CASTestCase(unittest.HomeserverTestCase): | |
370 | @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") | |
371 | class MultiSSOTestCase(unittest.HomeserverTestCase): | |
372 | """Tests for homeservers with multiple SSO providers enabled""" | |
314 | 373 | |
315 | 374 | servlets = [ |
316 | 375 | login.register_servlets, |
317 | 376 | ] |
318 | 377 | |
378 | def default_config(self) -> Dict[str, Any]: | |
379 | config = super().default_config() | |
380 | ||
381 | config["public_baseurl"] = BASE_URL | |
382 | ||
383 | config["cas_config"] = { | |
384 | "enabled": True, | |
385 | "server_url": CAS_SERVER, | |
386 | "service_url": "https://matrix.goodserver.com:8448", | |
387 | } | |
388 | ||
389 | config["saml2_config"] = { | |
390 | "sp_config": { | |
391 | "metadata": {"inline": [TEST_SAML_METADATA]}, | |
392 | # use the XMLSecurity backend to avoid relying on xmlsec1 | |
393 | "crypto_backend": "XMLSecurity", | |
394 | }, | |
395 | } | |
396 | ||
397 | # default OIDC provider | |
398 | config["oidc_config"] = TEST_OIDC_CONFIG | |
399 | ||
400 | # additional OIDC providers | |
401 | config["oidc_providers"] = [ | |
402 | { | |
403 | "idp_id": "idp1", | |
404 | "idp_name": "IDP1", | |
405 | "discover": False, | |
406 | "issuer": "https://issuer1", | |
407 | "client_id": "test-client-id", | |
408 | "client_secret": "test-client-secret", | |
409 | "scopes": ["profile"], | |
410 | "authorization_endpoint": "https://issuer1/auth", | |
411 | "token_endpoint": "https://issuer1/token", | |
412 | "userinfo_endpoint": "https://issuer1/userinfo", | |
413 | "user_mapping_provider": { | |
414 | "config": {"localpart_template": "{{ user.sub }}"} | |
415 | }, | |
416 | } | |
417 | ] | |
418 | return config | |
419 | ||
420 | def create_resource_dict(self) -> Dict[str, Resource]: | |
421 | from synapse.rest.oidc import OIDCResource | |
422 | ||
423 | d = super().create_resource_dict() | |
424 | d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) | |
425 | d["/_synapse/oidc"] = OIDCResource(self.hs) | |
426 | return d | |
427 | ||
428 | def test_multi_sso_redirect(self): | |
429 | """/login/sso/redirect should redirect to an identity picker""" | |
430 | # first hit the redirect url, which should redirect to our idp picker | |
431 | channel = self.make_request( | |
432 | "GET", | |
433 | "/_matrix/client/r0/login/sso/redirect?redirectUrl=" | |
434 | + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), | |
435 | ) | |
436 | self.assertEqual(channel.code, 302, channel.result) | |
437 | uri = channel.headers.getRawHeaders("Location")[0] | |
438 | ||
439 | # hitting that picker should give us some HTML | |
440 | channel = self.make_request("GET", uri) | |
441 | self.assertEqual(channel.code, 200, channel.result) | |
442 | ||
443 | # parse the form to check it has fields assumed elsewhere in this class | |
444 | p = TestHtmlParser() | |
445 | p.feed(channel.result["body"].decode("utf-8")) | |
446 | p.close() | |
447 | ||
448 | self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) | |
449 | ||
450 | self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) | |
451 | ||
452 | def test_multi_sso_redirect_to_cas(self): | |
453 | """If CAS is chosen, should redirect to the CAS server""" | |
454 | ||
455 | channel = self.make_request( | |
456 | "GET", | |
457 | "/_synapse/client/pick_idp?redirectUrl=" | |
458 | + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) | |
459 | + "&idp=cas", | |
460 | shorthand=False, | |
461 | ) | |
462 | self.assertEqual(channel.code, 302, channel.result) | |
463 | cas_uri = channel.headers.getRawHeaders("Location")[0] | |
464 | cas_uri_path, cas_uri_query = cas_uri.split("?", 1) | |
465 | ||
466 | # it should redirect us to the login page of the cas server | |
467 | self.assertEqual(cas_uri_path, CAS_SERVER + "/login") | |
468 | ||
469 | # check that the redirectUrl is correctly encoded in the service param - ie, the | |
470 | # place that CAS will redirect to | |
471 | cas_uri_params = urllib.parse.parse_qs(cas_uri_query) | |
472 | service_uri = cas_uri_params["service"][0] | |
473 | _, service_uri_query = service_uri.split("?", 1) | |
474 | service_uri_params = urllib.parse.parse_qs(service_uri_query) | |
475 | self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) | |
476 | ||
477 | def test_multi_sso_redirect_to_saml(self): | |
478 | """If SAML is chosen, should redirect to the SAML server""" | |
479 | channel = self.make_request( | |
480 | "GET", | |
481 | "/_synapse/client/pick_idp?redirectUrl=" | |
482 | + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) | |
483 | + "&idp=saml", | |
484 | ) | |
485 | self.assertEqual(channel.code, 302, channel.result) | |
486 | saml_uri = channel.headers.getRawHeaders("Location")[0] | |
487 | saml_uri_path, saml_uri_query = saml_uri.split("?", 1) | |
488 | ||
489 | # it should redirect us to the login page of the SAML server | |
490 | self.assertEqual(saml_uri_path, SAML_SERVER) | |
491 | ||
492 | # the RelayState is used to carry the client redirect url | |
493 | saml_uri_params = urllib.parse.parse_qs(saml_uri_query) | |
494 | relay_state_param = saml_uri_params["RelayState"][0] | |
495 | self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) | |
496 | ||
497 | def test_login_via_oidc(self): | |
498 | """If OIDC is chosen, should redirect to the OIDC auth endpoint""" | |
499 | ||
500 | # pick the default OIDC provider | |
501 | channel = self.make_request( | |
502 | "GET", | |
503 | "/_synapse/client/pick_idp?redirectUrl=" | |
504 | + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) | |
505 | + "&idp=oidc", | |
506 | ) | |
507 | self.assertEqual(channel.code, 302, channel.result) | |
508 | oidc_uri = channel.headers.getRawHeaders("Location")[0] | |
509 | oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) | |
510 | ||
511 | # it should redirect us to the auth page of the OIDC server | |
512 | self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) | |
513 | ||
514 | # ... and should have set a cookie including the redirect url | |
515 | cookies = dict( | |
516 | h.split(";")[0].split("=", maxsplit=1) | |
517 | for h in channel.headers.getRawHeaders("Set-Cookie") | |
518 | ) | |
519 | ||
520 | oidc_session_cookie = cookies["oidc_session"] | |
521 | macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) | |
522 | self.assertEqual( | |
523 | self._get_value_from_macaroon(macaroon, "client_redirect_url"), | |
524 | TEST_CLIENT_REDIRECT_URL, | |
525 | ) | |
526 | ||
527 | channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) | |
528 | ||
529 | # that should serve a confirmation page | |
530 | self.assertEqual(channel.code, 200, channel.result) | |
531 | self.assertTrue( | |
532 | channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") | |
533 | ) | |
534 | p = TestHtmlParser() | |
535 | p.feed(channel.text_body) | |
536 | p.close() | |
537 | ||
538 | # ... which should contain our redirect link | |
539 | self.assertEqual(len(p.links), 1) | |
540 | path, query = p.links[0].split("?", 1) | |
541 | self.assertEqual(path, "https://x") | |
542 | ||
543 | # it will have url-encoded the params properly, so we'll have to parse them | |
544 | params = urllib.parse.parse_qsl( | |
545 | query, keep_blank_values=True, strict_parsing=True, errors="strict" | |
546 | ) | |
547 | self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) | |
548 | self.assertEqual(params[2][0], "loginToken") | |
549 | ||
550 | # finally, submit the matrix login token to the login API, which gives us our | |
551 | # matrix access token, mxid, and device id. | |
552 | login_token = params[2][1] | |
553 | chan = self.make_request( | |
554 | "POST", "/login", content={"type": "m.login.token", "token": login_token}, | |
555 | ) | |
556 | self.assertEqual(chan.code, 200, chan.result) | |
557 | self.assertEqual(chan.json_body["user_id"], "@user1:test") | |
558 | ||
559 | def test_multi_sso_redirect_to_unknown(self): | |
560 | """An unknown IdP should cause a 400""" | |
561 | channel = self.make_request( | |
562 | "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", | |
563 | ) | |
564 | self.assertEqual(channel.code, 400, channel.result) | |
565 | ||
566 | @staticmethod | |
567 | def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: | |
568 | prefix = key + " = " | |
569 | for caveat in macaroon.caveats: | |
570 | if caveat.caveat_id.startswith(prefix): | |
571 | return caveat.caveat_id[len(prefix) :] | |
572 | raise ValueError("No %s caveat in macaroon" % (key,)) | |
573 | ||
574 | ||
575 | class CASTestCase(unittest.HomeserverTestCase): | |
576 | ||
577 | servlets = [ | |
578 | login.register_servlets, | |
579 | ] | |
580 | ||
319 | 581 | def make_homeserver(self, reactor, clock): |
320 | 582 | self.base_url = "https://matrix.goodserver.com/" |
321 | 583 | self.redirect_path = "_synapse/client/login/sso/redirect/confirm" |
323 | 585 | config = self.default_config() |
324 | 586 | config["cas_config"] = { |
325 | 587 | "enabled": True, |
326 | "server_url": "https://fake.test", | |
588 | "server_url": CAS_SERVER, | |
327 | 589 | "service_url": "https://matrix.goodserver.com:8448", |
328 | 590 | } |
329 | 591 | |
384 | 646 | channel = self.make_request("GET", cas_ticket_url) |
385 | 647 | |
386 | 648 | # Test that the response is HTML. |
387 | self.assertEqual(channel.code, 200) | |
649 | self.assertEqual(channel.code, 200, channel.result) | |
388 | 650 | content_type_header_value = "" |
389 | 651 | for header in channel.result.get("headers", []): |
390 | 652 | if header[0] == b"Content-Type": |
409 | 671 | } |
410 | 672 | ) |
411 | 673 | def test_cas_redirect_whitelisted(self): |
412 | """Tests that the SSO login flow serves a redirect to a whitelisted url | |
413 | """ | |
674 | """Tests that the SSO login flow serves a redirect to a whitelisted url""" | |
414 | 675 | self._test_redirect("https://legit-site.com/") |
415 | 676 | |
416 | 677 | @override_config({"public_baseurl": "https://example.com"}) |
441 | 702 | |
442 | 703 | # Deactivate the account. |
443 | 704 | self.get_success( |
444 | self.deactivate_account_handler.deactivate_account(self.user_id, False) | |
705 | self.deactivate_account_handler.deactivate_account( | |
706 | self.user_id, False, create_requester(self.user_id) | |
707 | ) | |
445 | 708 | ) |
446 | 709 | |
447 | 710 | # Request the CAS ticket. |
458 | 721 | self.assertIn(b"SSO account deactivated", channel.result["body"]) |
459 | 722 | |
460 | 723 | |
724 | @skip_unless(HAS_JWT, "requires jwt") | |
461 | 725 | class JWTTestCase(unittest.HomeserverTestCase): |
462 | 726 | servlets = [ |
463 | 727 | synapse.rest.admin.register_servlets_for_client_rest_resource, |
474 | 738 | self.hs.config.jwt_algorithm = self.jwt_algorithm |
475 | 739 | return self.hs |
476 | 740 | |
477 | def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: | |
741 | def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: | |
478 | 742 | # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. |
479 | result = jwt.encode(token, secret, self.jwt_algorithm) | |
743 | result = jwt.encode( | |
744 | payload, secret, self.jwt_algorithm | |
745 | ) # type: Union[str, bytes] | |
480 | 746 | if isinstance(result, bytes): |
481 | 747 | return result.decode("ascii") |
482 | 748 | return result |
483 | 749 | |
484 | 750 | def jwt_login(self, *args): |
485 | params = json.dumps( | |
486 | {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | |
487 | ) | |
751 | params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | |
488 | 752 | channel = self.make_request(b"POST", LOGIN_URL, params) |
489 | 753 | return channel |
490 | 754 | |
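Both JWT test cases repeat the same bytes-vs-str normalization around `jwt.encode`. A hypothetical shared helper (not part of this diff) factoring out the version-compat shim:

```python
from typing import Any, Dict, Union

import jwt


def encode_jwt(payload: Dict[str, Any], secret: str, algorithm: str) -> str:
    """Encode a JWT, normalizing PyJWT's return type: versions before
    2.0 return bytes, 2.0 and later return str."""
    result = jwt.encode(payload, secret, algorithm)  # type: Union[str, bytes]
    if isinstance(result, bytes):
        return result.decode("ascii")
    return result
```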
616 | 880 | ) |
617 | 881 | |
618 | 882 | def test_login_no_token(self): |
619 | params = json.dumps({"type": "org.matrix.login.jwt"}) | |
883 | params = {"type": "org.matrix.login.jwt"} | |
620 | 884 | channel = self.make_request(b"POST", LOGIN_URL, params) |
621 | 885 | self.assertEqual(channel.result["code"], b"403", channel.result) |
622 | 886 | self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") |
626 | 890 | # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use |
627 | 891 | # RS256, with a public key configured in synapse as "jwt_secret", and tokens
628 | 892 | # signed by the private key. |
893 | @skip_unless(HAS_JWT, "requires jwt") | |
629 | 894 | class JWTPubKeyTestCase(unittest.HomeserverTestCase): |
630 | 895 | servlets = [ |
631 | 896 | login.register_servlets, |
683 | 948 | self.hs.config.jwt_algorithm = "RS256" |
684 | 949 | return self.hs |
685 | 950 | |
686 | def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: | |
951 | def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: | |
687 | 952 | # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. |
688 | result = jwt.encode(token, secret, "RS256") | |
953 | result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] | |
689 | 954 | if isinstance(result, bytes): |
690 | 955 | return result.decode("ascii") |
691 | 956 | return result |
692 | 957 | |
693 | 958 | def jwt_login(self, *args): |
694 | params = json.dumps( | |
695 | {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | |
696 | ) | |
959 | params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | |
697 | 960 | channel = self.make_request(b"POST", LOGIN_URL, params) |
698 | 961 | return channel |
699 | 962 | |
763 | 1026 | return self.hs |
764 | 1027 | |
765 | 1028 | def test_login_appservice_user(self): |
766 | """Test that an appservice user can use /login | |
767 | """ | |
1029 | """Test that an appservice user can use /login""" | |
768 | 1030 | self.register_as_user(AS_USER) |
769 | 1031 | |
770 | 1032 | params = { |
778 | 1040 | self.assertEquals(channel.result["code"], b"200", channel.result) |
779 | 1041 | |
780 | 1042 | def test_login_appservice_user_bot(self): |
781 | """Test that the appservice bot can use /login | |
782 | """ | |
1043 | """Test that the appservice bot can use /login""" | |
783 | 1044 | self.register_as_user(AS_USER) |
784 | 1045 | |
785 | 1046 | params = { |
793 | 1054 | self.assertEquals(channel.result["code"], b"200", channel.result) |
794 | 1055 | |
795 | 1056 | def test_login_appservice_wrong_user(self): |
796 | """Test that non-as users cannot login with the as token | |
797 | """ | |
1057 | """Test that non-as users cannot login with the as token""" | |
798 | 1058 | self.register_as_user(AS_USER) |
799 | 1059 | |
800 | 1060 | params = { |
808 | 1068 | self.assertEquals(channel.result["code"], b"403", channel.result) |
809 | 1069 | |
810 | 1070 | def test_login_appservice_wrong_as(self): |
811 | """Test that as users cannot login with wrong as token | |
812 | """ | |
1071 | """Test that as users cannot login with wrong as token""" | |
813 | 1072 | self.register_as_user(AS_USER) |
814 | 1073 | |
815 | 1074 | params = { |
824 | 1083 | |
825 | 1084 | def test_login_appservice_no_token(self): |
826 | 1085 | """Test that users must provide a token when using the appservice |
827 | login method | |
1086 | login method | |
828 | 1087 | """ |
829 | 1088 | self.register_as_user(AS_USER) |
830 | 1089 | |
835 | 1094 | channel = self.make_request(b"POST", LOGIN_URL, params) |
836 | 1095 | |
837 | 1096 | self.assertEquals(channel.result["code"], b"401", channel.result) |
1097 | ||
1098 | ||
1099 | @skip_unless(HAS_OIDC, "requires OIDC") | |
1100 | class UsernamePickerTestCase(HomeserverTestCase): | |
1101 | """Tests for the username picker flow of SSO login""" | |
1102 | ||
1103 | servlets = [login.register_servlets] | |
1104 | ||
1105 | def default_config(self): | |
1106 | config = super().default_config() | |
1107 | config["public_baseurl"] = BASE_URL | |
1108 | ||
1109 | config["oidc_config"] = {} | |
1110 | config["oidc_config"].update(TEST_OIDC_CONFIG) | |
1111 | config["oidc_config"]["user_mapping_provider"] = { | |
1112 | "config": {"display_name_template": "{{ user.displayname }}"} | |
1113 | } | |
1114 | ||
1115 | # whitelist this client URI so we redirect straight to it rather than | |
1116 | # serving a confirmation page | |
1117 | config["sso"] = {"client_whitelist": ["https://x"]} | |
1118 | return config | |
1119 | ||
1120 | def create_resource_dict(self) -> Dict[str, Resource]: | |
1121 | from synapse.rest.oidc import OIDCResource | |
1122 | ||
1123 | d = super().create_resource_dict() | |
1124 | d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) | |
1125 | d["/_synapse/oidc"] = OIDCResource(self.hs) | |
1126 | return d | |
1127 | ||
1128 | def test_username_picker(self): | |
1129 | """Test the happy path of a username picker flow.""" | |
1130 | ||
1131 | # do the start of the login flow | |
1132 | channel = self.helper.auth_via_oidc( | |
1133 | {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL | |
1134 | ) | |
1135 | ||
1136 | # that should redirect to the username picker | |
1137 | self.assertEqual(channel.code, 302, channel.result) | |
1138 | picker_url = channel.headers.getRawHeaders("Location")[0] | |
1139 | self.assertEqual(picker_url, "/_synapse/client/pick_username") | |
1140 | ||
1141 | # ... with a username_mapping_session cookie | |
1142 | cookies = {} # type: Dict[str,str] | |
1143 | channel.extract_cookies(cookies) | |
1144 | self.assertIn("username_mapping_session", cookies) | |
1145 | session_id = cookies["username_mapping_session"] | |
1146 | ||
1147 | # introspect the sso handler a bit to check that the username mapping session | |
1148 | # looks ok. | |
1149 | username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions | |
1150 | self.assertIn( | |
1151 | session_id, username_mapping_sessions, "session id not found in map", | |
1152 | ) | |
1153 | session = username_mapping_sessions[session_id] | |
1154 | self.assertEqual(session.remote_user_id, "tester") | |
1155 | self.assertEqual(session.display_name, "Jonny") | |
1156 | self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) | |
1157 | ||
1158 | # the expiry time should be about 15 minutes away | |
1159 | expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) | |
1160 | self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) | |
1161 | ||
1162 | # Now, submit a username to the username picker, which should serve a redirect | |
1163 | # back to the client | |
1164 | submit_path = picker_url + "/submit" | |
1165 | content = urlencode({b"username": b"bobby"}).encode("utf8") | |
1166 | chan = self.make_request( | |
1167 | "POST", | |
1168 | path=submit_path, | |
1169 | content=content, | |
1170 | content_is_form=True, | |
1171 | custom_headers=[ | |
1172 | ("Cookie", "username_mapping_session=" + session_id), | |
1173 | # old versions of twisted don't do form-parsing without a valid | |
1174 | # content-length header. | |
1175 | ("Content-Length", str(len(content))), | |
1176 | ], | |
1177 | ) | |
1178 | self.assertEqual(chan.code, 302, chan.result) | |
1179 | location_headers = chan.headers.getRawHeaders("Location") | |
1180 | # ensure that the returned location matches the requested redirect URL | |
1181 | path, query = location_headers[0].split("?", 1) | |
1182 | self.assertEqual(path, "https://x") | |
1183 | ||
1184 | # it will have url-encoded the params properly, so we'll have to parse them | |
1185 | params = urllib.parse.parse_qsl( | |
1186 | query, keep_blank_values=True, strict_parsing=True, errors="strict" | |
1187 | ) | |
1188 | self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) | |
1189 | self.assertEqual(params[2][0], "loginToken") | |
1190 | ||
1191 | # fish the login token out of the returned redirect uri | |
1192 | login_token = params[2][1] | |
1193 | ||
1194 | # finally, submit the matrix login token to the login API, which gives us our | |
1195 | # matrix access token, mxid, and device id. | |
1196 | chan = self.make_request( | |
1197 | "POST", "/login", content={"type": "m.login.token", "token": login_token}, | |
1198 | ) | |
1199 | self.assertEqual(chan.code, 200, chan.result) | |
1200 | self.assertEqual(chan.json_body["user_id"], "@bobby:test") |
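The tail of the test parses the redirect query string by hand with `parse_qsl`. The same extraction, condensed into a hypothetical helper (not part of the patch), for readers following the flow:

    import urllib.parse

    def extract_login_token(location: str) -> str:
        # `location` is the value of the Location header on the 302 response.
        query = urllib.parse.urlparse(location).query
        params = dict(urllib.parse.parse_qsl(query, keep_blank_values=True))
        return params["loginToken"]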
28 | 28 | from synapse.rest import admin |
29 | 29 | from synapse.rest.client.v1 import directory, login, profile, room |
30 | 30 | from synapse.rest.client.v2_alpha import account |
31 | from synapse.types import JsonDict, RoomAlias, UserID | |
31 | from synapse.types import JsonDict, RoomAlias, UserID, create_requester | |
32 | 32 | from synapse.util.stringutils import random_string |
33 | 33 | |
34 | 34 | from tests import unittest |
1686 | 1686 | |
1687 | 1687 | deactivate_account_handler = self.hs.get_deactivate_account_handler() |
1688 | 1688 | self.get_success( |
1689 | deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) | |
1689 | deactivate_account_handler.deactivate_account( | |
1690 | self.user_id, True, create_requester(self.user_id) | |
1691 | ) | |
1690 | 1692 | ) |
1691 | 1693 | |
1692 | 1694 | # Invite another user in the room. This is needed because messages will be |
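Both updated call sites reflect `deactivate_account` growing a third positional argument carrying the acting requester. Inferring the parameter meanings from the old keyword form (`erase_data=True`), the new call shape is, in sketch form:

    from synapse.types import create_requester

    # (user_id, erase_data, requester) - the third argument identifies who
    # is performing the deactivation; here the user deactivates themselves.
    self.get_success(
        deactivate_account_handler.deactivate_account(
            self.user_id, True, create_requester(self.user_id)
        )
    )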
1 | 1 | # Copyright 2014-2016 OpenMarket Ltd |
2 | 2 | # Copyright 2017 Vector Creations Ltd |
3 | 3 | # Copyright 2018-2019 New Vector Ltd |
4 | # Copyright 2019-2020 The Matrix.org Foundation C.I.C. | |
4 | # Copyright 2019-2021 The Matrix.org Foundation C.I.C. | |
5 | 5 | # |
6 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); |
7 | 7 | # you may not use this file except in compliance with the License. |
19 | 19 | import re |
20 | 20 | import time |
21 | 21 | import urllib.parse |
22 | from typing import Any, Dict, Optional | |
22 | from typing import Any, Dict, Mapping, MutableMapping, Optional | |
23 | 23 | |
24 | 24 | from mock import patch |
25 | 25 | |
31 | 31 | from synapse.api.constants import Membership |
32 | 32 | from synapse.types import JsonDict |
33 | 33 | |
34 | from tests.server import FakeSite, make_request | |
34 | from tests.server import FakeChannel, FakeSite, make_request | |
35 | 35 | from tests.test_utils import FakeResponse |
36 | from tests.test_utils.html_parsers import TestHtmlParser | |
36 | 37 | |
37 | 38 | |
38 | 39 | @attr.s |
361 | 362 | the normal places. |
362 | 363 | """ |
363 | 364 | client_redirect_url = "https://x" |
364 | ||
365 | # first hit the redirect url (which will issue a cookie and state) | |
365 | channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) | |
366 | ||
367 | # expect a confirmation page | |
368 | assert channel.code == 200, channel.result | |
369 | ||
370 | # fish the matrix login token out of the body of the confirmation page | |
371 | m = re.search( | |
372 | 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), | |
373 | channel.text_body, | |
374 | ) | |
375 | assert m, channel.text_body | |
376 | login_token = m.group(1) | |
377 | ||
378 | # finally, submit the matrix login token to the login API, which gives us our | |
379 | # matrix access token and device id. | |
366 | 380 | channel = make_request( |
367 | 381 | self.hs.get_reactor(), |
368 | 382 | self.site, |
369 | "GET", | |
370 | "/login/sso/redirect?redirectUrl=" + client_redirect_url, | |
371 | ) | |
372 | # that will redirect to the OIDC IdP, but we skip that and go straight | |
383 | "POST", | |
384 | "/login", | |
385 | content={"type": "m.login.token", "token": login_token}, | |
386 | ) | |
387 | assert channel.code == 200 | |
388 | return channel.json_body | |
389 | ||
390 | def auth_via_oidc( | |
391 | self, | |
392 | user_info_dict: JsonDict, | |
393 | client_redirect_url: Optional[str] = None, | |
394 | ui_auth_session_id: Optional[str] = None, | |
395 | ) -> FakeChannel: | |
396 | """Perform an OIDC authentication flow via a mock OIDC provider. | |
397 | ||
398 | This can be used for either login or user-interactive auth. | |
399 | ||
400 | Starts by making a request to the relevant synapse redirect endpoint, which is | |
401 | expected to serve a 302 to the OIDC provider. We then make a request to the | |
402 | OIDC callback endpoint, intercepting the HTTP requests that will get sent back | |
403 | to the OIDC provider. | |
404 | ||
405 | Requires that "oidc_config" in the homeserver config be set appropriately | |
406 | (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a | |
407 | "public_base_url". | |
408 | ||
409 | Also requires the login servlet and the OIDC callback resource to be mounted at | |
410 | the normal places. | |
411 | ||
412 | Args: | |
413 | user_info_dict: the remote userinfo that the OIDC provider should present. | |
414 | Typically this should be '{"sub": "<remote user id>"}'. | |
415 | client_redirect_url: for a login flow, the client redirect URL to pass to | |
416 | the login redirect endpoint | |
417 | ui_auth_session_id: if set, we will perform a UI Auth flow. The session id | |
418 | of the UI auth. | |
419 | ||
420 | Returns: | |
421 | A FakeChannel containing the result of calling the OIDC callback endpoint. | |
422 | Note that the response code may be a 200, 302 or 400 depending on how things | |
423 | went. | |
424 | """ | |
425 | ||
426 | cookies = {} | |
427 | ||
428 | # if we're doing a ui auth, hit the ui auth redirect endpoint | |
429 | if ui_auth_session_id: | |
430 | # can't set the client redirect url for UI Auth | |
431 | assert client_redirect_url is None | |
432 | oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) | |
433 | else: | |
434 | # otherwise, hit the login redirect endpoint | |
435 | oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) | |
436 | ||
437 | # we now have a URI for the OIDC IdP, but we skip that and go straight | |
373 | 438 | # back to synapse's OIDC callback resource. However, we do need the "state" |
374 | # param that synapse passes to the IdP via query params, and the cookie that | |
375 | # synapse passes to the client. | |
376 | assert channel.code == 302 | |
377 | oauth_uri = channel.headers.getRawHeaders("Location")[0] | |
378 | params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) | |
379 | redirect_uri = "%s?%s" % ( | |
439 | # param that synapse passes to the IdP via query params, as well as the cookie | |
440 | # that synapse passes to the client. | |
441 | ||
442 | oauth_uri_path, _ = oauth_uri.split("?", 1) | |
443 | assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( | |
444 | "unexpected SSO URI " + oauth_uri_path | |
445 | ) | |
446 | return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) | |
447 | ||
448 | def complete_oidc_auth( | |
449 | self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, | |
450 | ) -> FakeChannel: | |
451 | """Mock out an OIDC authentication flow | |
452 | ||
453 | Assumes that an OIDC auth has been initiated by one of initiate_sso_login or | |
454 | initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to | |
455 | Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get | |
456 | sent back to the OIDC provider. | |
457 | ||
458 | Requires the OIDC callback resource to be mounted at the normal place. | |
459 | ||
460 | Args: | |
461 | oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, | |
462 | from initiate_sso_login or initiate_sso_ui_auth). | |
463 | cookies: the cookies set by synapse's redirect endpoint, which will be | |
464 | sent back to the callback endpoint. | |
465 | user_info_dict: the remote userinfo that the OIDC provider should present. | |
466 | Typically this should be '{"sub": "<remote user id>"}'. | |
467 | ||
468 | Returns: | |
469 | A FakeChannel containing the result of calling the OIDC callback endpoint. | |
470 | """ | |
471 | _, oauth_uri_qs = oauth_uri.split("?", 1) | |
472 | params = urllib.parse.parse_qs(oauth_uri_qs) | |
473 | callback_uri = "%s?%s" % ( | |
380 | 474 | urllib.parse.urlparse(params["redirect_uri"][0]).path, |
381 | 475 | urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), |
382 | 476 | ) |
383 | cookies = {} | |
384 | for h in channel.headers.getRawHeaders("Set-Cookie"): | |
385 | parts = h.split(";") | |
386 | k, v = parts[0].split("=", maxsplit=1) | |
387 | cookies[k] = v | |
388 | 477 | |
389 | 478 | # before we hit the callback uri, stub out some methods in the http client so |
390 | 479 | # that we don't have to handle full HTTPS requests. |
391 | ||
392 | 480 | # (expected url, json response) pairs, in the order we expect them. |
393 | 481 | expected_requests = [ |
394 | 482 | # first we get a hit to the token endpoint, which we tell to return |
395 | 483 | # a dummy OIDC access token |
396 | ("https://issuer.test/token", {"access_token": "TEST"}), | |
484 | (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), | |
397 | 485 | # and then one to the user_info endpoint, which returns our remote user id. |
398 | ("https://issuer.test/userinfo", {"sub": remote_user_id}), | |
486 | (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), | |
399 | 487 | ] |
400 | 488 | |
401 | 489 | async def mock_req(method: str, uri: str, data=None, headers=None): |
412 | 500 | self.hs.get_reactor(), |
413 | 501 | self.site, |
414 | 502 | "GET", |
415 | redirect_uri, | |
503 | callback_uri, | |
416 | 504 | custom_headers=[ |
417 | 505 | ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() |
418 | 506 | ], |
419 | 507 | ) |
420 | ||
421 | # expect a confirmation page | |
422 | assert channel.code == 200 | |
423 | ||
424 | # fish the matrix login token out of the body of the confirmation page | |
425 | m = re.search( | |
426 | 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), | |
427 | channel.result["body"].decode("utf-8"), | |
428 | ) | |
429 | assert m | |
430 | login_token = m.group(1) | |
431 | ||
432 | # finally, submit the matrix login token to the login API, which gives us our | |
433 | # matrix access token and device id. | |
508 | return channel | |
509 | ||
510 | def initiate_sso_login( | |
511 | self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] | |
512 | ) -> str: | |
513 | """Make a request to the login-via-sso redirect endpoint, and return the target | |
514 | ||
515 | Assumes that exactly one SSO provider has been configured. Requires the login | |
516 | servlet to be mounted. | |
517 | ||
518 | Args: | |
519 | client_redirect_url: the client redirect URL to pass to the login redirect | |
520 | endpoint | |
521 | cookies: any cookies returned will be added to this dict | |
522 | ||
523 | Returns: | |
524 | the URI that the client gets redirected to (ie, the SSO server) | |
525 | """ | |
526 | params = {} | |
527 | if client_redirect_url: | |
528 | params["redirectUrl"] = client_redirect_url | |
529 | ||
530 | # hit the redirect url (which will issue a cookie and state) | |
434 | 531 | channel = make_request( |
435 | 532 | self.hs.get_reactor(), |
436 | 533 | self.site, |
437 | "POST", | |
438 | "/login", | |
439 | content={"type": "m.login.token", "token": login_token}, | |
440 | ) | |
441 | assert channel.code == 200 | |
442 | return channel.json_body | |
534 | "GET", | |
535 | "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), | |
536 | ) | |
537 | ||
538 | assert channel.code == 302 | |
539 | channel.extract_cookies(cookies) | |
540 | return channel.headers.getRawHeaders("Location")[0] | |
541 | ||
542 | def initiate_sso_ui_auth( | |
543 | self, ui_auth_session_id: str, cookies: MutableMapping[str, str] | |
544 | ) -> str: | |
545 | """Make a request to the ui-auth-via-sso endpoint, and return the target | |
546 | ||
547 | Assumes that exactly one SSO provider has been configured. Requires the | |
548 | AuthRestServlet to be mounted. | |
549 | ||
550 | Args: | |
551 | ui_auth_session_id: the session id of the UI auth | |
552 | cookies: any cookies returned will be added to this dict | |
553 | ||
554 | Returns: | |
555 | the URI that the client gets linked to (ie, the SSO server) | |
556 | """ | |
557 | sso_redirect_endpoint = ( | |
558 | "/_matrix/client/r0/auth/m.login.sso/fallback/web?" | |
559 | + urllib.parse.urlencode({"session": ui_auth_session_id}) | |
560 | ) | |
561 | # hit the redirect url (which will issue a cookie and state) | |
562 | channel = make_request( | |
563 | self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint | |
564 | ) | |
565 | # that should serve a confirmation page | |
566 | assert channel.code == 200, channel.text_body | |
567 | channel.extract_cookies(cookies) | |
568 | ||
569 | # parse the confirmation page to fish out the link. | |
570 | p = TestHtmlParser() | |
571 | p.feed(channel.text_body) | |
572 | p.close() | |
573 | assert len(p.links) == 1, "not exactly one link in confirmation page" | |
574 | oauth_uri = p.links[0] | |
575 | return oauth_uri | |
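`auth_via_oidc` is now a thin composition of `initiate_sso_login` (or `initiate_sso_ui_auth`) and `complete_oidc_auth`. A sketch of driving the split helpers directly from a test, e.g. to inspect the intermediate IdP URI (assumes TEST_OIDC_CONFIG and the OIDC callback resource are mounted as described in the docstrings above):

    cookies = {}  # type: Dict[str, str]
    oauth_uri = self.helper.initiate_sso_login("https://x", cookies)
    # oauth_uri now points at the (fake) IdP; skip it and complete the flow.
    channel = self.helper.complete_oidc_auth(oauth_uri, cookies, {"sub": "tester"})
    assert channel.code in (200, 302), channel.result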
443 | 576 | |
444 | 577 | |
445 | 578 | # an 'oidc_config' suitable for login_via_oidc. |
579 | TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" | |
580 | TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" | |
581 | TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" | |
446 | 582 | TEST_OIDC_CONFIG = { |
447 | 583 | "enabled": True, |
448 | 584 | "discover": False, |
450 | 586 | "client_id": "test-client-id", |
451 | 587 | "client_secret": "test-client-secret", |
452 | 588 | "scopes": ["profile"], |
453 | "authorization_endpoint": "https://z", | |
454 | "token_endpoint": "https://issuer.test/token", | |
455 | "userinfo_endpoint": "https://issuer.test/userinfo", | |
589 | "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, | |
590 | "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, | |
591 | "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, | |
456 | 592 | "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, |
457 | 593 | } |
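Hoisting the endpoint URLs into module-level constants lets callers assert against them, as `auth_via_oidc` now does. For context, the pattern by which tests later in this diff consume the config (a sketch; the test name is hypothetical):

    @skip_unless(HAS_OIDC, "requires OIDC")
    @override_config({"oidc_config": TEST_OIDC_CONFIG})
    def test_oidc_flow(self):
        channel = self.helper.auth_via_oidc({"sub": "tester"}, "https://x")
        # 200 (confirmation page) or 302 (whitelisted redirect), per the
        # auth_via_oidc docstring.
        self.assertIn(channel.code, (200, 302))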
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2018 New Vector |
2 | # Copyright 2020-2021 The Matrix.org Foundation C.I.C | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
11 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 13 | # See the License for the specific language governing permissions and |
13 | 14 | # limitations under the License. |
14 | ||
15 | 15 | from typing import Union |
16 | 16 | |
17 | 17 | from twisted.internet.defer import succeed |
25 | 25 | from synapse.types import JsonDict, UserID |
26 | 26 | |
27 | 27 | from tests import unittest |
28 | from tests.handlers.test_oidc import HAS_OIDC | |
28 | 29 | from tests.rest.client.v1.utils import TEST_OIDC_CONFIG |
29 | 30 | from tests.server import FakeChannel |
31 | from tests.unittest import override_config, skip_unless | |
30 | 32 | |
31 | 33 | |
32 | 34 | class DummyRecaptchaChecker(UserInteractiveAuthChecker): |
157 | 159 | |
158 | 160 | def default_config(self): |
159 | 161 | config = super().default_config() |
160 | ||
161 | # we enable OIDC as a way of testing SSO flows | |
162 | oidc_config = {} | |
163 | oidc_config.update(TEST_OIDC_CONFIG) | |
164 | oidc_config["allow_existing_users"] = True | |
165 | ||
166 | config["oidc_config"] = oidc_config | |
167 | 162 | config["public_baseurl"] = "https://synapse.test" |
163 | ||
164 | if HAS_OIDC: | |
165 | # we enable OIDC as a way of testing SSO flows | |
166 | oidc_config = {} | |
167 | oidc_config.update(TEST_OIDC_CONFIG) | |
168 | oidc_config["allow_existing_users"] = True | |
169 | config["oidc_config"] = oidc_config | |
170 | ||
168 | 171 | return config |
169 | 172 | |
170 | 173 | def create_resource_dict(self): |
171 | 174 | resource_dict = super().create_resource_dict() |
172 | # mount the OIDC resource at /_synapse/oidc | |
173 | resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) | |
175 | if HAS_OIDC: | |
176 | # mount the OIDC resource at /_synapse/oidc | |
177 | resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) | |
174 | 178 | return resource_dict |
175 | 179 | |
176 | 180 | def prepare(self, reactor, clock, hs): |
379 | 383 | # Note that *no auth* information is provided, not even a session ID!
380 | 384 | self.delete_device(self.user_tok, self.device_id, 200) |
381 | 385 | |
386 | @skip_unless(HAS_OIDC, "requires OIDC") | |
387 | @override_config({"oidc_config": TEST_OIDC_CONFIG}) | |
388 | def test_ui_auth_via_sso(self): | |
389 | """Test a successful UI Auth flow via SSO | |
390 | ||
391 | This includes: | |
392 | * hitting the UIA SSO redirect endpoint | |
393 | * checking it serves a confirmation page which links to the OIDC provider | |
394 | * calling back to the synapse oidc callback | |
395 | * checking that the original operation succeeds | |
396 | """ | |
397 | ||
398 | # log the user in | |
399 | remote_user_id = UserID.from_string(self.user).localpart | |
400 | login_resp = self.helper.login_via_oidc(remote_user_id) | |
401 | self.assertEqual(login_resp["user_id"], self.user) | |
402 | ||
403 | # initiate a UI Auth process by attempting to delete the device | |
404 | channel = self.delete_device(self.user_tok, self.device_id, 401) | |
405 | ||
406 | # check that SSO is offered | |
407 | flows = channel.json_body["flows"] | |
408 | self.assertIn({"stages": ["m.login.sso"]}, flows) | |
409 | ||
410 | # run the UIA-via-SSO flow | |
411 | session_id = channel.json_body["session"] | |
412 | channel = self.helper.auth_via_oidc( | |
413 | {"sub": remote_user_id}, ui_auth_session_id=session_id | |
414 | ) | |
415 | ||
416 | # that should serve a confirmation page | |
417 | self.assertEqual(channel.code, 200, channel.result) | |
418 | ||
419 | # and now the delete request should succeed. | |
420 | self.delete_device( | |
421 | self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, | |
422 | ) | |
423 | ||
424 | @skip_unless(HAS_OIDC, "requires OIDC") | |
425 | @override_config({"oidc_config": TEST_OIDC_CONFIG}) | |
382 | 426 | def test_does_not_offer_password_for_sso_user(self): |
383 | 427 | login_resp = self.helper.login_via_oidc("username") |
384 | 428 | user_tok = login_resp["access_token"] |
392 | 436 | self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) |
393 | 437 | |
394 | 438 | def test_does_not_offer_sso_for_password_user(self): |
395 | # now call the device deletion API: we should get the option to auth with SSO | |
396 | # and not password. | |
397 | 439 | channel = self.delete_device(self.user_tok, self.device_id, 401) |
398 | 440 | |
399 | 441 | flows = channel.json_body["flows"] |
400 | 442 | self.assertEqual(flows, [{"stages": ["m.login.password"]}]) |
401 | 443 | |
444 | @skip_unless(HAS_OIDC, "requires OIDC") | |
445 | @override_config({"oidc_config": TEST_OIDC_CONFIG}) | |
402 | 446 | def test_offers_both_flows_for_upgraded_user(self): |
403 | 447 | """A user that had a password and then logged in with SSO should get both flows |
404 | 448 | """ |
412 | 456 | self.assertIn({"stages": ["m.login.password"]}, flows) |
413 | 457 | self.assertIn({"stages": ["m.login.sso"]}, flows) |
414 | 458 | self.assertEqual(len(flows), 2) |
459 | ||
460 | @skip_unless(HAS_OIDC, "requires OIDC") | |
461 | @override_config({"oidc_config": TEST_OIDC_CONFIG}) | |
462 | def test_ui_auth_fails_for_incorrect_sso_user(self): | |
463 | """If the user tries to authenticate with the wrong SSO user, they get an error | |
464 | """ | |
465 | # log the user in | |
466 | login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) | |
467 | self.assertEqual(login_resp["user_id"], self.user) | |
468 | ||
469 | # start a UI Auth flow by attempting to delete a device | |
470 | channel = self.delete_device(self.user_tok, self.device_id, 401) | |
471 | ||
472 | flows = channel.json_body["flows"] | |
473 | self.assertIn({"stages": ["m.login.sso"]}, flows) | |
474 | session_id = channel.json_body["session"] | |
475 | ||
476 | # do the OIDC auth, but auth as the wrong user | |
477 | channel = self.helper.auth_via_oidc( | |
478 | {"sub": "wrong_user"}, ui_auth_session_id=session_id | |
479 | ) | |
480 | ||
481 | # that should return a failure message | |
482 | self.assertSubstring("We were unable to validate", channel.text_body) | |
483 | ||
484 | # ... and the delete op should now fail with a 403 | |
485 | self.delete_device( | |
486 | self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} | |
487 | ) |
25 | 25 | from tests import unittest |
26 | 26 | from tests.server import FakeTransport |
27 | 27 | |
28 | try: | |
29 | import lxml | |
30 | except ImportError: | |
31 | lxml = None | |
32 | ||
28 | 33 | |
29 | 34 | class URLPreviewTests(unittest.HomeserverTestCase): |
35 | if not lxml: | |
36 | skip = "url preview feature requires lxml" | |
30 | 37 | |
31 | 38 | hijack_auth = True |
32 | 39 | user_id = "@test:user" |
39 | 39 | "m.identity_server": {"base_url": "https://testis"}, |
40 | 40 | }, |
41 | 41 | ) |
42 | ||
43 | def test_well_known_no_public_baseurl(self): | |
44 | self.hs.config.public_baseurl = None | |
45 | ||
46 | channel = self.make_request( | |
47 | "GET", "/.well-known/matrix/client", shorthand=False | |
48 | ) | |
49 | ||
50 | self.assertEqual(channel.code, 404) |
1 | 1 | import logging |
2 | 2 | from collections import deque |
3 | 3 | from io import SEEK_END, BytesIO |
4 | from typing import Callable, Iterable, Optional, Tuple, Union | |
4 | from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union | |
5 | 5 | |
6 | 6 | import attr |
7 | 7 | from typing_extensions import Deque |
50 | 50 | |
51 | 51 | @property |
52 | 52 | def json_body(self): |
53 | if not self.result: | |
54 | raise Exception("No result yet.") | |
55 | return json.loads(self.result["body"].decode("utf8")) | |
53 | return json.loads(self.text_body) | |
54 | ||
55 | @property | |
56 | def text_body(self) -> str: | |
57 | """The body of the result, utf-8-decoded. | |
58 | ||
59 | Raises an exception if the request has not yet completed. | |
60 | """ | |
61 | if not self.is_finished(): | 
62 | raise Exception("Request not yet completed") | |
63 | return self.result["body"].decode("utf8") | |
64 | ||
65 | def is_finished(self) -> bool: | |
66 | """check if the response has been completely received""" | |
67 | return self.result.get("done", False) | |
56 | 68 | |
57 | 69 | @property |
58 | 70 | def code(self): |
61 | 73 | return int(self.result["code"]) |
62 | 74 | |
63 | 75 | @property |
64 | def headers(self): | |
76 | def headers(self) -> Headers: | |
65 | 77 | if not self.result: |
66 | 78 | raise Exception("No result yet.") |
67 | 79 | h = Headers() |
123 | 135 | self._reactor.run() |
124 | 136 | x = 0 |
125 | 137 | |
126 | while not self.result.get("done"): | |
138 | while not self.is_finished(): | |
127 | 139 | # If there's a producer, tell it to resume producing so we get content |
128 | 140 | if self._producer: |
129 | 141 | self._producer.resumeProducing() |
134 | 146 | raise TimedOutException("Timed out waiting for request to finish.") |
135 | 147 | |
136 | 148 | self._reactor.advance(0.1) |
149 | ||
150 | def extract_cookies(self, cookies: MutableMapping[str, str]) -> None: | |
151 | """Process the contents of any Set-Cookie headers in the response | |
152 | ||
153 | Any cookies found are added to the given dict. | 
154 | """ | |
155 | for h in self.headers.getRawHeaders("Set-Cookie"): | |
156 | parts = h.split(";") | |
157 | k, v = parts[0].split("=", maxsplit=1) | |
158 | cookies[k] = v | |
137 | 159 | |
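Taken together, the new `FakeChannel` helpers make response handling in tests more uniform. A sketch of typical usage (the path and surrounding setup are hypothetical):

    cookies = {}  # type: Dict[str, str]
    channel = make_request(self.hs.get_reactor(), self.site, "GET", "/some/path")
    assert channel.is_finished()      # the response has been fully received
    body = channel.text_body          # utf-8-decoded body; raises if unfinished
    channel.extract_cookies(cookies)  # accumulate any Set-Cookie headers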
138 | 160 | |
139 | 161 | class FakeSite: |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from typing import Optional, Set | 
16 | ||
17 | from synapse.api.constants import AccountDataTypes | |
18 | ||
19 | from tests import unittest | |
20 | ||
21 | ||
22 | class IgnoredUsersTestCase(unittest.HomeserverTestCase): | |
23 | def prepare(self, hs, reactor, clock): | |
24 | self.store = self.hs.get_datastore() | |
25 | self.user = "@user:test" | |
26 | ||
27 | def _update_ignore_list( | |
28 |     self, *ignored_user_ids: str, ignorer_user_id: Optional[str] = None | 
29 | ) -> None: | |
30 | """Update the account data to block the given users.""" | |
31 | if ignorer_user_id is None: | |
32 | ignorer_user_id = self.user | |
33 | ||
34 | self.get_success( | |
35 | self.store.add_account_data_for_user( | |
36 | ignorer_user_id, | |
37 | AccountDataTypes.IGNORED_USER_LIST, | |
38 | {"ignored_users": {u: {} for u in ignored_user_ids}}, | |
39 | ) | |
40 | ) | |
41 | ||
42 | def assert_ignorers( | |
43 | self, ignored_user_id: str, expected_ignorer_user_ids: Set[str] | |
44 | ) -> None: | |
45 | self.assertEqual( | |
46 | self.get_success(self.store.ignored_by(ignored_user_id)), | |
47 | expected_ignorer_user_ids, | |
48 | ) | |
49 | ||
50 | def test_ignoring_users(self): | |
51 | """Basic adding/removing of users from the ignore list.""" | |
52 | self._update_ignore_list("@other:test", "@another:remote") | |
53 | ||
54 | # Check a user which no one ignores. | |
55 | self.assert_ignorers("@user:test", set()) | |
56 | ||
57 | # Check a local user which is ignored. | |
58 | self.assert_ignorers("@other:test", {self.user}) | |
59 | ||
60 | # Check a remote user which is ignored. | |
61 | self.assert_ignorers("@another:remote", {self.user}) | |
62 | ||
63 | # Add one user, remove one user, and leave one user. | |
64 | self._update_ignore_list("@foo:test", "@another:remote") | |
65 | ||
66 | # Check the removed user. | |
67 | self.assert_ignorers("@other:test", set()) | |
68 | ||
69 | # Check the added user. | |
70 | self.assert_ignorers("@foo:test", {self.user}) | |
71 | ||
72 | # Check the user that was left. | 
73 | self.assert_ignorers("@another:remote", {self.user}) | |
74 | ||
75 | def test_caching(self): | |
76 | """Ensure that caching works properly between different users.""" | |
77 | # The first user ignores a user. | |
78 | self._update_ignore_list("@other:test") | |
79 | self.assert_ignorers("@other:test", {self.user}) | |
80 | ||
81 | # The second user ignores them. | |
82 | self._update_ignore_list("@other:test", ignorer_user_id="@second:test") | |
83 | self.assert_ignorers("@other:test", {self.user, "@second:test"}) | |
84 | ||
85 | # The first user un-ignores them. | |
86 | self._update_ignore_list() | |
87 | self.assert_ignorers("@other:test", {"@second:test"}) | |
88 | ||
89 | def test_invalid_data(self): | |
90 | """Invalid data ends up clearing out the ignored users list.""" | |
91 | # Add some data and ensure it is there. | |
92 | self._update_ignore_list("@other:test") | |
93 | self.assert_ignorers("@other:test", {self.user}) | |
94 | ||
95 | # No ignored_users key. | |
96 | self.get_success( | |
97 | self.store.add_account_data_for_user( | |
98 | self.user, AccountDataTypes.IGNORED_USER_LIST, {}, | |
99 | ) | |
100 | ) | |
101 | ||
102 | # No one ignores the user now. | |
103 | self.assert_ignorers("@other:test", set()) | |
104 | ||
105 | # Add some data and ensure it is there. | |
106 | self._update_ignore_list("@other:test") | |
107 | self.assert_ignorers("@other:test", {self.user}) | |
108 | ||
109 | # Invalid data. | |
110 | self.get_success( | |
111 | self.store.add_account_data_for_user( | |
112 | self.user, | |
113 | AccountDataTypes.IGNORED_USER_LIST, | |
114 | {"ignored_users": "unexpected"}, | |
115 | ) | |
116 | ) | |
117 | ||
118 | # No one ignores the user now. | |
119 | self.assert_ignorers("@other:test", set()) |
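For reference, the account-data content these tests write under `AccountDataTypes.IGNORED_USER_LIST` (the `m.ignored_user_list` event type) maps each ignored user ID to a per-user options dict, currently always empty:

    content = {
        "ignored_users": {
            "@other:test": {},
            "@another:remote": {},
        }
    }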
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2020 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the 'License'); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an 'AS IS' BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from typing import Dict, List, Set, Tuple | |
16 | ||
17 | from twisted.trial import unittest | |
18 | ||
19 | from synapse.api.constants import EventTypes | |
20 | from synapse.api.room_versions import RoomVersions | |
21 | from synapse.events import EventBase | |
22 | from synapse.rest import admin | |
23 | from synapse.rest.client.v1 import login, room | |
24 | from synapse.storage.databases.main.events import _LinkMap | |
25 | from synapse.types import create_requester | |
26 | ||
27 | from tests.unittest import HomeserverTestCase | |
28 | ||
29 | ||
30 | class EventChainStoreTestCase(HomeserverTestCase): | |
31 | def prepare(self, reactor, clock, hs): | |
32 | self.store = hs.get_datastore() | |
33 | self._next_stream_ordering = 1 | |
34 | ||
35 | def test_simple(self): | |
36 | """Test that the example in `docs/auth_chain_difference_algorithm.md` | |
37 | works. | |
38 | """ | |
39 | ||
40 | event_factory = self.hs.get_event_builder_factory() | |
41 | bob = "@creator:test" | |
42 | alice = "@alice:test" | |
43 | room_id = "!room:test" | |
44 | ||
45 | # Ensure that we have a rooms entry so that we generate the chain index. | |
46 | self.get_success( | |
47 | self.store.store_room( | |
48 | room_id=room_id, | |
49 | room_creator_user_id="", | |
50 | is_public=True, | |
51 | room_version=RoomVersions.V6, | |
52 | ) | |
53 | ) | |
54 | ||
55 | create = self.get_success( | |
56 | event_factory.for_room_version( | |
57 | RoomVersions.V6, | |
58 | { | |
59 | "type": EventTypes.Create, | |
60 | "state_key": "", | |
61 | "sender": bob, | |
62 | "room_id": room_id, | |
63 | "content": {"tag": "create"}, | |
64 | }, | |
65 | ).build(prev_event_ids=[], auth_event_ids=[]) | |
66 | ) | |
67 | ||
68 | bob_join = self.get_success( | |
69 | event_factory.for_room_version( | |
70 | RoomVersions.V6, | |
71 | { | |
72 | "type": EventTypes.Member, | |
73 | "state_key": bob, | |
74 | "sender": bob, | |
75 | "room_id": room_id, | |
76 | "content": {"tag": "bob_join"}, | |
77 | }, | |
78 | ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) | |
79 | ) | |
80 | ||
81 | power = self.get_success( | |
82 | event_factory.for_room_version( | |
83 | RoomVersions.V6, | |
84 | { | |
85 | "type": EventTypes.PowerLevels, | |
86 | "state_key": "", | |
87 | "sender": bob, | |
88 | "room_id": room_id, | |
89 | "content": {"tag": "power"}, | |
90 | }, | |
91 | ).build( | |
92 | prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], | |
93 | ) | |
94 | ) | |
95 | ||
96 | alice_invite = self.get_success( | |
97 | event_factory.for_room_version( | |
98 | RoomVersions.V6, | |
99 | { | |
100 | "type": EventTypes.Member, | |
101 | "state_key": alice, | |
102 | "sender": bob, | |
103 | "room_id": room_id, | |
104 | "content": {"tag": "alice_invite"}, | |
105 | }, | |
106 | ).build( | |
107 | prev_event_ids=[], | |
108 | auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], | |
109 | ) | |
110 | ) | |
111 | ||
112 | alice_join = self.get_success( | |
113 | event_factory.for_room_version( | |
114 | RoomVersions.V6, | |
115 | { | |
116 | "type": EventTypes.Member, | |
117 | "state_key": alice, | |
118 | "sender": alice, | |
119 | "room_id": room_id, | |
120 | "content": {"tag": "alice_join"}, | |
121 | }, | |
122 | ).build( | |
123 | prev_event_ids=[], | |
124 | auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], | |
125 | ) | |
126 | ) | |
127 | ||
128 | power_2 = self.get_success( | |
129 | event_factory.for_room_version( | |
130 | RoomVersions.V6, | |
131 | { | |
132 | "type": EventTypes.PowerLevels, | |
133 | "state_key": "", | |
134 | "sender": bob, | |
135 | "room_id": room_id, | |
136 | "content": {"tag": "power_2"}, | |
137 | }, | |
138 | ).build( | |
139 | prev_event_ids=[], | |
140 | auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], | |
141 | ) | |
142 | ) | |
143 | ||
144 | bob_join_2 = self.get_success( | |
145 | event_factory.for_room_version( | |
146 | RoomVersions.V6, | |
147 | { | |
148 | "type": EventTypes.Member, | |
149 | "state_key": bob, | |
150 | "sender": bob, | |
151 | "room_id": room_id, | |
152 | "content": {"tag": "bob_join_2"}, | |
153 | }, | |
154 | ).build( | |
155 | prev_event_ids=[], | |
156 | auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], | |
157 | ) | |
158 | ) | |
159 | ||
160 | alice_join2 = self.get_success( | |
161 | event_factory.for_room_version( | |
162 | RoomVersions.V6, | |
163 | { | |
164 | "type": EventTypes.Member, | |
165 | "state_key": alice, | |
166 | "sender": alice, | |
167 | "room_id": room_id, | |
168 | "content": {"tag": "alice_join2"}, | |
169 | }, | |
170 | ).build( | |
171 | prev_event_ids=[], | |
172 | auth_event_ids=[ | |
173 | create.event_id, | |
174 | alice_join.event_id, | |
175 | power_2.event_id, | |
176 | ], | |
177 | ) | |
178 | ) | |
179 | ||
180 | events = [ | |
181 | create, | |
182 | bob_join, | |
183 | power, | |
184 | alice_invite, | |
185 | alice_join, | |
186 | bob_join_2, | |
187 | power_2, | |
188 | alice_join2, | |
189 | ] | |
190 | ||
191 | expected_links = [ | |
192 | (bob_join, create), | |
193 | (power, create), | |
194 | (power, bob_join), | |
195 | (alice_invite, create), | |
196 | (alice_invite, power), | |
197 | (alice_invite, bob_join), | |
198 | (bob_join_2, power), | |
199 | (alice_join2, power_2), | |
200 | ] | |
201 | ||
202 | self.persist(events) | |
203 | chain_map, link_map = self.fetch_chains(events) | |
204 | ||
205 | # Check that the expected links and only the expected links have been | |
206 | # added. | |
207 | self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) | |
208 | ||
209 | for start, end in expected_links: | |
210 | start_id, start_seq = chain_map[start.event_id] | |
211 | end_id, end_seq = chain_map[end.event_id] | |
212 | ||
213 | self.assertIn( | |
214 | (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) | |
215 | ) | |
216 | ||
217 | # Test that everything can reach the create event, but the create event | |
218 | # can't reach anything. | |
219 | for event in events[1:]: | |
220 | self.assertTrue( | |
221 | link_map.exists_path_from( | |
222 | chain_map[event.event_id], chain_map[create.event_id] | |
223 | ), | |
224 | ) | |
225 | ||
226 | self.assertFalse( | |
227 | link_map.exists_path_from( | |
228 | chain_map[create.event_id], chain_map[event.event_id], | |
229 | ), | |
230 | ) | |
231 | ||
232 | def test_out_of_order_events(self): | |
233 | """Test that we handle persisting events that we don't have the full | |
234 | auth chain for yet (which should only happen for out of band memberships). | |
235 | """ | |
236 | event_factory = self.hs.get_event_builder_factory() | |
237 | bob = "@creator:test" | |
238 | alice = "@alice:test" | |
239 | room_id = "!room:test" | |
240 | ||
241 | # Ensure that we have a rooms entry so that we generate the chain index. | |
242 | self.get_success( | |
243 | self.store.store_room( | |
244 | room_id=room_id, | |
245 | room_creator_user_id="", | |
246 | is_public=True, | |
247 | room_version=RoomVersions.V6, | |
248 | ) | |
249 | ) | |
250 | ||
251 | # First persist the base room. | |
252 | create = self.get_success( | |
253 | event_factory.for_room_version( | |
254 | RoomVersions.V6, | |
255 | { | |
256 | "type": EventTypes.Create, | |
257 | "state_key": "", | |
258 | "sender": bob, | |
259 | "room_id": room_id, | |
260 | "content": {"tag": "create"}, | |
261 | }, | |
262 | ).build(prev_event_ids=[], auth_event_ids=[]) | |
263 | ) | |
264 | ||
265 | bob_join = self.get_success( | |
266 | event_factory.for_room_version( | |
267 | RoomVersions.V6, | |
268 | { | |
269 | "type": EventTypes.Member, | |
270 | "state_key": bob, | |
271 | "sender": bob, | |
272 | "room_id": room_id, | |
273 | "content": {"tag": "bob_join"}, | |
274 | }, | |
275 | ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) | |
276 | ) | |
277 | ||
278 | power = self.get_success( | |
279 | event_factory.for_room_version( | |
280 | RoomVersions.V6, | |
281 | { | |
282 | "type": EventTypes.PowerLevels, | |
283 | "state_key": "", | |
284 | "sender": bob, | |
285 | "room_id": room_id, | |
286 | "content": {"tag": "power"}, | |
287 | }, | |
288 | ).build( | |
289 | prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], | |
290 | ) | |
291 | ) | |
292 | ||
293 | self.persist([create, bob_join, power]) | |
294 | ||
295 | # Now persist an invite and a couple of memberships out of order. | |
296 | alice_invite = self.get_success( | |
297 | event_factory.for_room_version( | |
298 | RoomVersions.V6, | |
299 | { | |
300 | "type": EventTypes.Member, | |
301 | "state_key": alice, | |
302 | "sender": bob, | |
303 | "room_id": room_id, | |
304 | "content": {"tag": "alice_invite"}, | |
305 | }, | |
306 | ).build( | |
307 | prev_event_ids=[], | |
308 | auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], | |
309 | ) | |
310 | ) | |
311 | ||
312 | alice_join = self.get_success( | |
313 | event_factory.for_room_version( | |
314 | RoomVersions.V6, | |
315 | { | |
316 | "type": EventTypes.Member, | |
317 | "state_key": alice, | |
318 | "sender": alice, | |
319 | "room_id": room_id, | |
320 | "content": {"tag": "alice_join"}, | |
321 | }, | |
322 | ).build( | |
323 | prev_event_ids=[], | |
324 | auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], | |
325 | ) | |
326 | ) | |
327 | ||
328 | alice_join2 = self.get_success( | |
329 | event_factory.for_room_version( | |
330 | RoomVersions.V6, | |
331 | { | |
332 | "type": EventTypes.Member, | |
333 | "state_key": alice, | |
334 | "sender": alice, | |
335 | "room_id": room_id, | |
336 | "content": {"tag": "alice_join2"}, | |
337 | }, | |
338 | ).build( | |
339 | prev_event_ids=[], | |
340 | auth_event_ids=[create.event_id, alice_join.event_id, power.event_id], | |
341 | ) | |
342 | ) | |
343 | ||
344 | self.persist([alice_join]) | |
345 | self.persist([alice_join2]) | |
346 | self.persist([alice_invite]) | |
347 | ||
348 | # The end result should be sane. | |
349 | events = [create, bob_join, power, alice_invite, alice_join] | |
350 | ||
351 | chain_map, link_map = self.fetch_chains(events) | |
352 | ||
353 | expected_links = [ | |
354 | (bob_join, create), | |
355 | (power, create), | |
356 | (power, bob_join), | |
357 | (alice_invite, create), | |
358 | (alice_invite, power), | |
359 | (alice_invite, bob_join), | |
360 | ] | |
361 | ||
362 | # Check that the expected links and only the expected links have been | |
363 | # added. | |
364 | self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) | |
365 | ||
366 | for start, end in expected_links: | |
367 | start_id, start_seq = chain_map[start.event_id] | |
368 | end_id, end_seq = chain_map[end.event_id] | |
369 | ||
370 | self.assertIn( | |
371 | (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) | |
372 | ) | |
373 | ||
374 | def persist( | |
375 | self, events: List[EventBase], | |
376 | ): | |
377 | """Persist the given events and check that the links generated match | |
378 | those given. | |
379 | """ | |
380 | ||
381 | persist_events_store = self.hs.get_datastores().persist_events | |
382 | ||
383 | for e in events: | |
384 | e.internal_metadata.stream_ordering = self._next_stream_ordering | |
385 | self._next_stream_ordering += 1 | |
386 | ||
387 | def _persist(txn): | |
388 | # We need to persist the events to the events and state_events | |
389 | # tables. | |
390 | persist_events_store._store_event_txn(txn, [(e, {}) for e in events]) | |
391 | ||
392 | # Actually call the function that calculates the auth chain stuff. | |
393 | persist_events_store._persist_event_auth_chain_txn(txn, events) | |
394 | ||
395 | self.get_success( | |
396 | persist_events_store.db_pool.runInteraction("_persist", _persist,) | |
397 | ) | |
398 | ||
399 | def fetch_chains( | |
400 | self, events: List[EventBase] | |
401 | ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: | |
402 | ||
403 | # Fetch the map from event ID -> (chain ID, sequence number) | |
404 | rows = self.get_success( | |
405 | self.store.db_pool.simple_select_many_batch( | |
406 | table="event_auth_chains", | |
407 | column="event_id", | |
408 | iterable=[e.event_id for e in events], | |
409 | retcols=("event_id", "chain_id", "sequence_number"), | |
410 | keyvalues={}, | |
411 | ) | |
412 | ) | |
413 | ||
414 | chain_map = { | |
415 | row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows | |
416 | } | |
417 | ||
418 | # Fetch all the links and pass them to the _LinkMap. | |
419 | rows = self.get_success( | |
420 | self.store.db_pool.simple_select_many_batch( | |
421 | table="event_auth_chain_links", | |
422 | column="origin_chain_id", | |
423 | iterable=[chain_id for chain_id, _ in chain_map.values()], | |
424 | retcols=( | |
425 | "origin_chain_id", | |
426 | "origin_sequence_number", | |
427 | "target_chain_id", | |
428 | "target_sequence_number", | |
429 | ), | |
430 | keyvalues={}, | |
431 | ) | |
432 | ) | |
433 | ||
434 | link_map = _LinkMap() | |
435 | for row in rows: | |
436 | added = link_map.add_link( | |
437 | (row["origin_chain_id"], row["origin_sequence_number"]), | |
438 | (row["target_chain_id"], row["target_sequence_number"]), | |
439 | ) | |
440 | ||
441 | # We shouldn't have persisted any redundant links | |
442 | self.assertTrue(added) | |
443 | ||
444 | return chain_map, link_map | |
445 | ||
446 | ||
447 | class LinkMapTestCase(unittest.TestCase): | |
448 | def test_simple(self): | |
449 | """Basic tests for the LinkMap. | |
450 | """ | |
451 | link_map = _LinkMap() | |
452 | ||
453 | link_map.add_link((1, 1), (2, 1), new=False) | |
454 | self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) | |
455 | self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) | |
456 | self.assertCountEqual(link_map.get_additions(), []) | |
457 | self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) | |
458 | self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) | |
459 | self.assertTrue(link_map.exists_path_from((1, 5), (1, 1))) | |
460 | self.assertFalse(link_map.exists_path_from((1, 1), (1, 5))) | |
461 | ||
462 | # Attempting to add a redundant link is ignored. | |
463 | self.assertFalse(link_map.add_link((1, 4), (2, 1))) | |
464 | self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) | |
465 | ||
466 | # Adding new non-redundant links works | |
467 | self.assertTrue(link_map.add_link((1, 3), (2, 3))) | |
468 | self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) | |
469 | ||
470 | self.assertTrue(link_map.add_link((2, 5), (1, 3))) | |
471 | self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) | |
472 | self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) | |
473 | ||
474 | self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) | |
475 | ||
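The assertions above pin down the `_LinkMap` semantics: a link records that a point `(chain id, sequence number)` in one chain can reach a point in another, and `add_link` returns False (storing nothing) when an existing link already implies the new one. Distilled (assuming redundancy detection behaves the same regardless of the `new` flag):

    link_map = _LinkMap()
    link_map.add_link((1, 1), (2, 1), new=False)  # chain 1 seq 1 -> chain 2 seq 1
    assert not link_map.add_link((1, 4), (2, 1))  # redundant: (1, 4) already reaches (2, 1)
    assert link_map.add_link((1, 3), (2, 3))      # genuinely new, so it is recorded
    assert set(link_map.get_additions()) == {(1, 3, 2, 3)}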
476 | ||
477 | class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |
478 | ||
479 | servlets = [ | |
480 | admin.register_servlets, | |
481 | room.register_servlets, | |
482 | login.register_servlets, | |
483 | ] | |
484 | ||
485 | def prepare(self, reactor, clock, hs): | |
486 | self.store = hs.get_datastore() | |
487 | self.user_id = self.register_user("foo", "pass") | |
488 | self.token = self.login("foo", "pass") | |
489 | self.requester = create_requester(self.user_id) | |
490 | ||
491 | def _generate_room(self) -> Tuple[str, List[Set[str]]]: | |
492 | """Insert a room without a chain cover index. | |
493 | """ | |
494 | room_id = self.helper.create_room_as(self.user_id, tok=self.token) | |
495 | ||
496 | # Mark the room as not having a chain cover index | |
497 | self.get_success( | |
498 | self.store.db_pool.simple_update( | |
499 | table="rooms", | |
500 | keyvalues={"room_id": room_id}, | |
501 | updatevalues={"has_auth_chain_index": False}, | |
502 | desc="test", | |
503 | ) | |
504 | ) | |
505 | ||
506 | # Create a fork in the DAG with different events. | |
507 | event_handler = self.hs.get_event_creation_handler() | |
508 | latest_event_ids = self.get_success( | |
509 | self.store.get_prev_events_for_room(room_id) | |
510 | ) | |
511 | event, context = self.get_success( | |
512 | event_handler.create_event( | |
513 | self.requester, | |
514 | { | |
515 | "type": "some_state_type", | |
516 | "state_key": "", | |
517 | "content": {}, | |
518 | "room_id": room_id, | |
519 | "sender": self.user_id, | |
520 | }, | |
521 | prev_event_ids=latest_event_ids, | |
522 | ) | |
523 | ) | |
524 | self.get_success( | |
525 | event_handler.handle_new_client_event(self.requester, event, context) | |
526 | ) | |
527 | state1 = set(self.get_success(context.get_current_state_ids()).values()) | |
528 | ||
529 | event, context = self.get_success( | |
530 | event_handler.create_event( | |
531 | self.requester, | |
532 | { | |
533 | "type": "some_state_type", | |
534 | "state_key": "", | |
535 | "content": {}, | |
536 | "room_id": room_id, | |
537 | "sender": self.user_id, | |
538 | }, | |
539 | prev_event_ids=latest_event_ids, | |
540 | ) | |
541 | ) | |
542 | self.get_success( | |
543 | event_handler.handle_new_client_event(self.requester, event, context) | |
544 | ) | |
545 | state2 = set(self.get_success(context.get_current_state_ids()).values()) | |
546 | ||
547 | # Delete the chain cover info. | |
548 | ||
549 | def _delete_tables(txn): | |
550 | txn.execute("DELETE FROM event_auth_chains") | |
551 | txn.execute("DELETE FROM event_auth_chain_links") | |
552 | ||
553 | self.get_success(self.store.db_pool.runInteraction("test", _delete_tables)) | |
554 | ||
555 | return room_id, [state1, state2] | |
556 | ||
557 | def test_background_update_single_room(self): | |
558 | """Test that the background update to calculate auth chains for historic | |
559 | rooms works correctly for a single room. | 
560 | """ | |
561 | ||
562 | # Create a room | |
563 | room_id, states = self._generate_room() | |
564 | ||
565 | # Insert and run the background update. | |
566 | self.get_success( | |
567 | self.store.db_pool.simple_insert( | |
568 | "background_updates", | |
569 | {"update_name": "chain_cover", "progress_json": "{}"}, | |
570 | ) | |
571 | ) | |
572 | ||
573 | # Ugh, have to reset this flag | |
574 | self.store.db_pool.updates._all_done = False | |
575 | ||
576 | while not self.get_success( | |
577 | self.store.db_pool.updates.has_completed_background_updates() | |
578 | ): | |
579 | self.get_success( | |
580 | self.store.db_pool.updates.do_next_background_update(100), by=0.1 | |
581 | ) | |
582 | ||
583 | # Test that the `has_auth_chain_index` flag has been set | 
584 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) | |
585 | ||
586 | # Test that calculating the auth chain difference using the newly | |
587 | # calculated chain cover works. | |
588 | self.get_success( | |
589 | self.store.db_pool.runInteraction( | |
590 | "test", | |
591 | self.store._get_auth_chain_difference_using_cover_index_txn, | |
592 | room_id, | |
593 | states, | |
594 | ) | |
595 | ) | |
596 | ||
597 | def test_background_update_multiple_rooms(self): | |
598 | """Test that the background update to calculate auth chains for historic | |
599 | rooms works correctly across several rooms at once. | 
600 | """ | |
601 | # Create some rooms | 
602 | room_id1, states1 = self._generate_room() | |
603 | room_id2, states2 = self._generate_room() | |
604 | room_id3, _ = self._generate_room() | 
605 | ||
606 | # Insert and run the background update. | |
607 | self.get_success( | |
608 | self.store.db_pool.simple_insert( | |
609 | "background_updates", | |
610 | {"update_name": "chain_cover", "progress_json": "{}"}, | |
611 | ) | |
612 | ) | |
613 | ||
614 | # Ugh, have to reset this flag | |
615 | self.store.db_pool.updates._all_done = False | |
616 | ||
617 | while not self.get_success( | |
618 | self.store.db_pool.updates.has_completed_background_updates() | |
619 | ): | |
620 | self.get_success( | |
621 | self.store.db_pool.updates.do_next_background_update(100), by=0.1 | |
622 | ) | |
623 | ||
624 | # Test that the `has_auth_chain_index` flag has been set | 
625 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) | |
626 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) | |
627 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3))) | |
628 | ||
629 | # Test that calculating the auth chain difference using the newly | |
630 | # calculated chain cover works. | |
631 | self.get_success( | |
632 | self.store.db_pool.runInteraction( | |
633 | "test", | |
634 | self.store._get_auth_chain_difference_using_cover_index_txn, | |
635 | room_id1, | |
636 | states1, | |
637 | ) | |
638 | ) | |
639 | ||
640 | def test_background_update_single_large_room(self): | |
641 | """Test that the background update to calculate auth chains for historic | |
642 | rooms works correctly for a room large enough to need several batches. | 
643 | """ | |
644 | ||
645 | # Create a room | |
646 | room_id, states = self._generate_room() | |
647 | ||
648 | # Add a bunch of state so that it takes multiple iterations of the | |
649 | # background update to process the room. | |
650 | for i in range(0, 150): | |
651 | self.helper.send_state( | |
652 | room_id, event_type="m.test", body={"index": i}, tok=self.token | |
653 | ) | |
654 | ||
655 | # Insert and run the background update. | |
656 | self.get_success( | |
657 | self.store.db_pool.simple_insert( | |
658 | "background_updates", | |
659 | {"update_name": "chain_cover", "progress_json": "{}"}, | |
660 | ) | |
661 | ) | |
662 | ||
663 | # Ugh, have to reset this flag | |
664 | self.store.db_pool.updates._all_done = False | |
665 | ||
666 | iterations = 0 | |
667 | while not self.get_success( | |
668 | self.store.db_pool.updates.has_completed_background_updates() | |
669 | ): | |
670 | iterations += 1 | |
671 | self.get_success( | |
672 | self.store.db_pool.updates.do_next_background_update(100), by=0.1 | |
673 | ) | |
674 | ||
675 | # Ensure that we did actually take multiple iterations to process the | |
676 | # room. | |
677 | self.assertGreater(iterations, 1) | |
678 | ||
679 | # Test that the `has_auth_chain_index` flag has been set | |
680 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) | |
681 | ||
682 | # Test that calculating the auth chain difference using the newly | |
683 | # calculated chain cover works. | |
684 | self.get_success( | |
685 | self.store.db_pool.runInteraction( | |
686 | "test", | |
687 | self.store._get_auth_chain_difference_using_cover_index_txn, | |
688 | room_id, | |
689 | states, | |
690 | ) | |
691 | ) | |
692 | ||
693 | def test_background_update_multiple_large_room(self): | |
694 | """Test that the background update to calculate auth chains for historic | |
695 | rooms works correctly for multiple rooms that each need multiple iterations. | |
696 | """ | |
697 | ||
698 | # Create the rooms | |
699 | room_id1, _ = self._generate_room() | |
700 | room_id2, _ = self._generate_room() | |
701 | ||
702 | # Add a bunch of state so that it takes multiple iterations of the | |
703 | # background update to process each room. | |
704 | for i in range(0, 150): | |
705 | self.helper.send_state( | |
706 | room_id1, event_type="m.test", body={"index": i}, tok=self.token | |
707 | ) | |
708 | ||
709 | for i in range(0, 150): | |
710 | self.helper.send_state( | |
711 | room_id2, event_type="m.test", body={"index": i}, tok=self.token | |
712 | ) | |
713 | ||
714 | # Insert and run the background update. | |
715 | self.get_success( | |
716 | self.store.db_pool.simple_insert( | |
717 | "background_updates", | |
718 | {"update_name": "chain_cover", "progress_json": "{}"}, | |
719 | ) | |
720 | ) | |
721 | ||
722 | # Ugh, have to reset this flag | |
723 | self.store.db_pool.updates._all_done = False | |
724 | ||
725 | iterations = 0 | |
726 | while not self.get_success( | |
727 | self.store.db_pool.updates.has_completed_background_updates() | |
728 | ): | |
729 | iterations += 1 | |
730 | self.get_success( | |
731 | self.store.db_pool.updates.do_next_background_update(100), by=0.1 | |
732 | ) | |
733 | ||
734 | # Ensure that we did actually take multiple iterations to process the | |
735 | # rooms. | |
736 | self.assertGreater(iterations, 1) | |
737 | ||
738 | # Test that the `has_auth_chain_index` flag has been set | |
739 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) | |
740 | self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) |
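
Each test above drives the chain cover background update the same way: register the update, clear the updater's cached "all done" flag, then pump `do_next_background_update` until it reports completion. As a sketch, the recurring pattern could be factored into a helper along these lines (hypothetical; the tests inline it rather than sharing a helper):

    def _run_background_update(self, update_name: str) -> None:
        # Register the update so the background updater picks it up.
        self.get_success(
            self.store.db_pool.simple_insert(
                "background_updates",
                {"update_name": update_name, "progress_json": "{}"},
            )
        )

        # The updater caches completion; reset it so it re-reads the table.
        self.store.db_pool.updates._all_done = False

        # Pump the updater until it reports everything has completed.
        while not self.get_success(
            self.store.db_pool.updates.has_completed_background_updates()
        ):
            self.get_success(
                self.store.db_pool.updates.do_next_background_update(100), by=0.1
            )
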
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import attr | |
16 | from parameterized import parameterized | |
17 | ||
18 | from synapse.events import _EventInternalMetadata | |
19 | ||
15 | 20 | import tests.unittest |
16 | 21 | import tests.utils |
17 | 22 | |
112 | 117 | r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) |
113 | 118 | self.assertTrue(r == [room2] or r == [room3]) |
114 | 119 | |
115 | def test_auth_difference(self): | |
120 | @parameterized.expand([(True,), (False,)]) | |
121 | def test_auth_difference(self, use_chain_cover_index: bool): | |
116 | 122 | room_id = "@ROOM:local" |
117 | 123 | |
118 | 124 | # The silly auth graph we use to test the auth difference algorithm, |
158 | 164 | "j": 1, |
159 | 165 | } |
160 | 166 | |
167 | # Mark whether the room has a chain cover index | |
168 | ||
169 | def store_room(txn): | |
170 | self.store.db_pool.simple_insert_txn( | |
171 | txn, | |
172 | "rooms", | |
173 | { | |
174 | "room_id": room_id, | |
175 | "creator": "room_creator_user_id", | |
176 | "is_public": True, | |
177 | "room_version": "6", | |
178 | "has_auth_chain_index": use_chain_cover_index, | |
179 | }, | |
180 | ) | |
181 | ||
182 | self.get_success(self.store.db_pool.runInteraction("store_room", store_room)) | |
183 | ||
161 | 184 | # We rudely fiddle with the appropriate tables directly, as that's much |
162 | 185 | # easier than constructing events properly. |
163 | 186 | |
164 | def insert_event(txn, event_id, stream_ordering): | |
165 | ||
166 | depth = depth_map[event_id] | |
167 | ||
187 | def insert_event(txn): | |
188 | stream_ordering = 0 | |
189 | ||
190 | for event_id in auth_graph: | |
191 | stream_ordering += 1 | |
192 | depth = depth_map[event_id] | |
193 | ||
194 | self.store.db_pool.simple_insert_txn( | |
195 | txn, | |
196 | table="events", | |
197 | values={ | |
198 | "event_id": event_id, | |
199 | "room_id": room_id, | |
200 | "depth": depth, | |
201 | "topological_ordering": depth, | |
202 | "type": "m.test", | |
203 | "processed": True, | |
204 | "outlier": False, | |
205 | "stream_ordering": stream_ordering, | |
206 | }, | |
207 | ) | |
208 | ||
209 | self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |
210 | txn, | |
211 | [ | |
212 | FakeEvent(event_id, room_id, auth_graph[event_id]) | |
213 | for event_id in auth_graph | |
214 | ], | |
215 | ) | |
216 | ||
217 | self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) | |
218 | ||
219 | # Now actually test that various combinations give the right result: | |
220 | ||
221 | difference = self.get_success( | |
222 | self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) | |
223 | ) | |
224 | self.assertSetEqual(difference, {"a", "b"}) | |
225 | ||
226 | difference = self.get_success( | |
227 | self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) | |
228 | ) | |
229 | self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) | |
230 | ||
231 | difference = self.get_success( | |
232 | self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) | |
233 | ) | |
234 | self.assertSetEqual(difference, {"a", "b", "c"}) | |
235 | ||
236 | difference = self.get_success( | |
237 | self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) | |
238 | ) | |
239 | self.assertSetEqual(difference, {"a", "b"}) | |
240 | ||
241 | difference = self.get_success( | |
242 | self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) | |
243 | ) | |
244 | self.assertSetEqual(difference, {"a", "b", "d", "e"}) | |
245 | ||
246 | difference = self.get_success( | |
247 | self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) | |
248 | ) | |
249 | self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) | |
250 | ||
251 | difference = self.get_success( | |
252 | self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) | |
253 | ) | |
254 | self.assertSetEqual(difference, {"a", "b"}) | |
255 | ||
256 | difference = self.get_success( | |
257 | self.store.get_auth_chain_difference(room_id, [{"a"}]) | |
258 | ) | |
259 | self.assertSetEqual(difference, set()) | |
260 | ||
261 | def test_auth_difference_partial_cover(self): | |
262 | """Test that we correctly handle rooms where not all events have a chain | |
263 | cover calculated. This can happen in some obscure edge cases, including | |
264 | during the background update that calculates the chain cover for old | |
265 | rooms. | |
266 | """ | |
267 | ||
268 | room_id = "@ROOM:local" | |
269 | ||
270 | # The silly auth graph we use to test the auth difference algorithm, | |
271 | # where the top are the most recent events. | |
272 | # | |
273 | # A B | |
274 | # \ / | |
275 | # D E | |
276 | # \ | | |
277 | # ` F C | |
278 | # | /| | |
279 | # G ´ | | |
280 | # | \ | | |
281 | # H I | |
282 | # | | | |
283 | # K J | |
284 | ||
285 | auth_graph = { | |
286 | "a": ["e"], | |
287 | "b": ["e"], | |
288 | "c": ["g", "i"], | |
289 | "d": ["f"], | |
290 | "e": ["f"], | |
291 | "f": ["g"], | |
292 | "g": ["h", "i"], | |
293 | "h": ["k"], | |
294 | "i": ["j"], | |
295 | "k": [], | |
296 | "j": [], | |
297 | } | |
298 | ||
299 | depth_map = { | |
300 | "a": 7, | |
301 | "b": 7, | |
302 | "c": 4, | |
303 | "d": 6, | |
304 | "e": 6, | |
305 | "f": 5, | |
306 | "g": 3, | |
307 | "h": 2, | |
308 | "i": 2, | |
309 | "k": 1, | |
310 | "j": 1, | |
311 | } | |
312 | ||
313 | # We rudely fiddle with the appropriate tables directly, as that's much | |
314 | # easier than constructing events properly. | |
315 | ||
316 | def insert_event(txn): | |
317 | # First insert the room and mark it as having a chain cover. | |
168 | 318 | self.store.db_pool.simple_insert_txn( |
169 | 319 | txn, |
170 | table="events", | |
171 | values={ | |
172 | "event_id": event_id, | |
320 | "rooms", | |
321 | { | |
173 | 322 | "room_id": room_id, |
174 | "depth": depth, | |
175 | "topological_ordering": depth, | |
176 | "type": "m.test", | |
177 | "processed": True, | |
178 | "outlier": False, | |
179 | "stream_ordering": stream_ordering, | |
323 | "creator": "room_creator_user_id", | |
324 | "is_public": True, | |
325 | "room_version": "6", | |
326 | "has_auth_chain_index": True, | |
180 | 327 | }, |
181 | 328 | ) |
182 | 329 | |
183 | self.store.db_pool.simple_insert_many_txn( | |
184 | txn, | |
185 | table="event_auth", | |
186 | values=[ | |
187 | {"event_id": event_id, "room_id": room_id, "auth_id": a} | |
188 | for a in auth_graph[event_id] | |
330 | stream_ordering = 0 | |
331 | ||
332 | for event_id in auth_graph: | |
333 | stream_ordering += 1 | |
334 | depth = depth_map[event_id] | |
335 | ||
336 | self.store.db_pool.simple_insert_txn( | |
337 | txn, | |
338 | table="events", | |
339 | values={ | |
340 | "event_id": event_id, | |
341 | "room_id": room_id, | |
342 | "depth": depth, | |
343 | "topological_ordering": depth, | |
344 | "type": "m.test", | |
345 | "processed": True, | |
346 | "outlier": False, | |
347 | "stream_ordering": stream_ordering, | |
348 | }, | |
349 | ) | |
350 | ||
351 | # Insert all events apart from 'B' | |
352 | self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |
353 | txn, | |
354 | [ | |
355 | FakeEvent(event_id, room_id, auth_graph[event_id]) | |
356 | for event_id in auth_graph | |
357 | if event_id != "b" | |
189 | 358 | ], |
190 | 359 | ) |
191 | 360 | |
192 | next_stream_ordering = 0 | |
193 | for event_id in auth_graph: | |
194 | next_stream_ordering += 1 | |
195 | self.get_success( | |
196 | self.store.db_pool.runInteraction( | |
197 | "insert", insert_event, event_id, next_stream_ordering | |
198 | ) | |
199 | ) | |
361 | # Now we insert the event 'B' without a chain cover, by temporarily | |
362 | # pretending the room doesn't have a chain cover. | |
363 | ||
364 | self.store.db_pool.simple_update_txn( | |
365 | txn, | |
366 | table="rooms", | |
367 | keyvalues={"room_id": room_id}, | |
368 | updatevalues={"has_auth_chain_index": False}, | |
369 | ) | |
370 | ||
371 | self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |
372 | txn, [FakeEvent("b", room_id, auth_graph["b"])], | |
373 | ) | |
374 | ||
375 | self.store.db_pool.simple_update_txn( | |
376 | txn, | |
377 | table="rooms", | |
378 | keyvalues={"room_id": room_id}, | |
379 | updatevalues={"has_auth_chain_index": True}, | |
380 | ) | |
381 | ||
382 | self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) | |
200 | 383 | |
201 | 384 | # Now actually test that various combinations give the right result: |
202 | 385 | |
239 | 422 | self.store.get_auth_chain_difference(room_id, [{"a"}]) |
240 | 423 | ) |
241 | 424 | self.assertSetEqual(difference, set()) |
425 | ||
426 | ||
427 | @attr.s | |
428 | class FakeEvent: | |
429 | event_id = attr.ib() | |
430 | room_id = attr.ib() | |
431 | auth_events = attr.ib() | |
432 | ||
433 | type = "foo" | |
434 | state_key = "foo" | |
435 | ||
436 | internal_metadata = _EventInternalMetadata({}) | |
437 | ||
438 | def auth_event_ids(self): | |
439 | return self.auth_events | |
440 | ||
441 | def is_state(self): | |
442 | return True |
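
The `@parameterized.expand` decorator applied to `test_auth_difference` above generates one test method per argument tuple, so the auth chain difference is exercised both with and without the chain cover index. A minimal standalone illustration of the mechanics (using the stdlib `TestCase`; not code from this changeset):

    import unittest

    from parameterized import parameterized


    class ParameterizedExample(unittest.TestCase):
        @parameterized.expand([(True,), (False,)])
        def test_flag(self, use_index: bool):
            # Expands into two separate test methods, one per tuple.
            self.assertIsInstance(use_index, bool)
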
50 | 50 | self.db_pool, |
51 | 51 | stream_name="test_stream", |
52 | 52 | instance_name=instance_name, |
53 | table="foobar", | |
54 | instance_column="instance_name", | |
55 | id_column="stream_id", | |
53 | tables=[("foobar", "instance_name", "stream_id")], | |
56 | 54 | sequence_name="foobar_seq", |
57 | 55 | writers=writers, |
58 | 56 | ) |
486 | 484 | self.db_pool, |
487 | 485 | stream_name="test_stream", |
488 | 486 | instance_name=instance_name, |
489 | table="foobar", | |
490 | instance_column="instance_name", | |
491 | id_column="stream_id", | |
487 | tables=[("foobar", "instance_name", "stream_id")], | |
492 | 488 | sequence_name="foobar_seq", |
493 | 489 | writers=writers, |
494 | 490 | positive=False, |
578 | 574 | self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) |
579 | 575 | self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) |
580 | 576 | self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) |
577 | ||
578 | ||
579 | class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): | |
580 | if not USE_POSTGRES_FOR_TESTS: | |
581 | skip = "Requires Postgres" | |
582 | ||
583 | def prepare(self, reactor, clock, hs): | |
584 | self.store = hs.get_datastore() | |
585 | self.db_pool = self.store.db_pool # type: DatabasePool | |
586 | ||
587 | self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) | |
588 | ||
589 | def _setup_db(self, txn): | |
590 | txn.execute("CREATE SEQUENCE foobar_seq") | |
591 | txn.execute( | |
592 | """ | |
593 | CREATE TABLE foobar1 ( | |
594 | stream_id BIGINT NOT NULL, | |
595 | instance_name TEXT NOT NULL, | |
596 | data TEXT | |
597 | ); | |
598 | """ | |
599 | ) | |
600 | ||
601 | txn.execute( | |
602 | """ | |
603 | CREATE TABLE foobar2 ( | |
604 | stream_id BIGINT NOT NULL, | |
605 | instance_name TEXT NOT NULL, | |
606 | data TEXT | |
607 | ); | |
608 | """ | |
609 | ) | |
610 | ||
611 | def _create_id_generator( | |
612 | self, instance_name="master", writers=["master"] | |
613 | ) -> MultiWriterIdGenerator: | |
614 | def _create(conn): | |
615 | return MultiWriterIdGenerator( | |
616 | conn, | |
617 | self.db_pool, | |
618 | stream_name="test_stream", | |
619 | instance_name=instance_name, | |
620 | tables=[ | |
621 | ("foobar1", "instance_name", "stream_id"), | |
622 | ("foobar2", "instance_name", "stream_id"), | |
623 | ], | |
624 | sequence_name="foobar_seq", | |
625 | writers=writers, | |
626 | ) | |
627 | ||
628 | return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) | |
629 | ||
630 | def _insert_rows( | |
631 | self, | |
632 | table: str, | |
633 | instance_name: str, | |
634 | number: int, | |
635 | update_stream_table: bool = True, | |
636 | ): | |
637 | """Insert N rows as the given instance, inserting with stream IDs pulled | |
638 | from the postgres sequence. | |
639 | """ | |
640 | ||
641 | def _insert(txn): | |
642 | for _ in range(number): | |
643 | txn.execute( | |
644 | "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), | |
645 | (instance_name,), | |
646 | ) | |
647 | if update_stream_table: | |
648 | txn.execute( | |
649 | """ | |
650 | INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) | |
651 | ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() | |
652 | """, | |
653 | (instance_name,), | |
654 | ) | |
655 | ||
656 | self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) | |
657 | ||
658 | def test_load_existing_stream(self): | |
659 | """Test creating ID gens with multiple tables that have rows from after | |
660 | the position in the `stream_positions` table. | |
661 | """ | |
662 | self._insert_rows("foobar1", "first", 3) | |
663 | self._insert_rows("foobar2", "second", 3) | |
664 | self._insert_rows("foobar2", "second", 1, update_stream_table=False) | |
665 | ||
666 | first_id_gen = self._create_id_generator("first", writers=["first", "second"]) | |
667 | second_id_gen = self._create_id_generator("second", writers=["first", "second"]) | |
668 | ||
669 | # The first ID gen will notice that it can advance its token to 7 as it | |
670 | # has no in progress writes... | |
671 | self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) | |
672 | self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) | |
673 | self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) | |
674 | self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) | |
675 | ||
676 | # ... but the second ID gen doesn't know that. | |
677 | self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) | |
678 | self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) | |
679 | self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) | |
680 | self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
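
The thread running through this file is an API change to `MultiWriterIdGenerator`: the single-table `table=` / `instance_column=` / `id_column=` keyword arguments are replaced by a `tables=` list of `(table, instance_column, id_column)` tuples, so one generator can track IDs spread across several tables that share a sequence. The call shape, following the tests above:

    id_gen = MultiWriterIdGenerator(
        conn,
        db_pool,
        stream_name="test_stream",
        instance_name="master",
        # Each tuple names a table plus its instance and stream-ID columns;
        # both tables here draw IDs from the shared "foobar_seq" sequence.
        tables=[
            ("foobar1", "instance_name", "stream_id"),
            ("foobar2", "instance_name", "stream_id"),
        ],
        sequence_name="foobar_seq",
        writers=["master"],
    )
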
47 | 47 | ), |
48 | 48 | ) |
49 | 49 | |
50 | # test set to None | |
51 | yield defer.ensureDeferred( | |
52 | self.store.set_profile_displayname(self.u_frank.localpart, None) | |
53 | ) | |
54 | ||
55 | self.assertIsNone( | |
56 | ( | |
57 | yield defer.ensureDeferred( | |
58 | self.store.get_profile_displayname(self.u_frank.localpart) | |
59 | ) | |
60 | ) | |
61 | ) | |
62 | ||
50 | 63 | @defer.inlineCallbacks |
51 | 64 | def test_avatar_url(self): |
52 | 65 | yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) |
65 | 78 | ) |
66 | 79 | ), |
67 | 80 | ) |
81 | ||
82 | # test set to None | |
83 | yield defer.ensureDeferred( | |
84 | self.store.set_profile_avatar_url(self.u_frank.localpart, None) | |
85 | ) | |
86 | ||
87 | self.assertIsNone( | |
88 | ( | |
89 | yield defer.ensureDeferred( | |
90 | self.store.get_profile_avatar_url(self.u_frank.localpart) | |
91 | ) | |
92 | ) | |
93 | ) |
19 | 19 | |
20 | 20 | from . import unittest |
21 | 21 | |
22 | try: | |
23 | import lxml | |
24 | except ImportError: | |
25 | lxml = None | |
26 | ||
22 | 27 | |
23 | 28 | class PreviewTestCase(unittest.TestCase): |
29 | if not lxml: | |
30 | skip = "url preview feature requires lxml" | |
31 | ||
24 | 32 | def test_long_summarize(self): |
25 | 33 | example_paras = [ |
26 | 34 | """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: |
136 | 144 | |
137 | 145 | |
138 | 146 | class PreviewUrlTestCase(unittest.TestCase): |
147 | if not lxml: | |
148 | skip = "url preview feature requires lxml" | |
149 | ||
139 | 150 | def test_simple(self): |
140 | 151 | html = """ |
141 | 152 | <html> |
57 | 57 | |
58 | 58 | self.assertEquals(room.to_string(), "#channel:my.domain") |
59 | 59 | |
60 | def test_validate(self): | |
61 | id_string = "#test:domain,test" | |
62 | self.assertFalse(RoomAlias.is_valid(id_string)) | |
63 | ||
60 | 64 | |
61 | 65 | class GroupIDTestCase(unittest.TestCase): |
62 | 66 | def test_parse(self): |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from html.parser import HTMLParser | |
16 | from typing import Dict, Iterable, List, Optional, Tuple | |
17 | ||
18 | ||
19 | class TestHtmlParser(HTMLParser): | |
20 | """A generic HTML page parser which extracts useful things from the HTML""" | |
21 | ||
22 | def __init__(self): | |
23 | super().__init__() | |
24 | ||
25 | # a list of links found in the doc | |
26 | self.links = [] # type: List[str] | |
27 | ||
28 | # the values of any hidden <input>s: map from name to value | |
29 | self.hiddens = {} # type: Dict[str, Optional[str]] | |
30 | ||
31 | # the values of any radio buttons: map from name to list of values | |
32 | self.radios = {} # type: Dict[str, List[Optional[str]]] | |
33 | ||
34 | def handle_starttag( | |
35 | self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] | |
36 | ) -> None: | |
37 | attr_dict = dict(attrs) | |
38 | if tag == "a": | |
39 | href = attr_dict["href"] | |
40 | if href: | |
41 | self.links.append(href) | |
42 | elif tag == "input": | |
43 | input_name = attr_dict.get("name") | |
44 | if attr_dict["type"] == "radio": | |
45 | assert input_name | |
46 | self.radios.setdefault(input_name, []).append(attr_dict["value"]) | |
47 | elif attr_dict["type"] == "hidden": | |
48 | assert input_name | |
49 | self.hiddens[input_name] = attr_dict["value"] | |
50 | ||
51 | def error(_, message): | |
52 | raise AssertionError(message) |
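
`TestHtmlParser` is driven through the standard `HTMLParser` feed/close interface; a usage sketch with made-up markup (not part of the changeset) would look like:

    parser = TestHtmlParser()
    parser.feed(
        '<a href="https://example.com/next">next</a>'
        '<input type="hidden" name="session" value="abc123">'
        '<input type="radio" name="idp" value="one">'
        '<input type="radio" name="idp" value="two">'
    )
    parser.close()

    assert parser.links == ["https://example.com/next"]
    assert parser.hiddens == {"session": "abc123"}
    assert parser.radios == {"idp": ["one", "two"]}
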
19 | 19 | import inspect |
20 | 20 | import logging |
21 | 21 | import time |
22 | from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union | |
22 | from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union | |
23 | 23 | |
24 | 24 | from mock import Mock, patch |
25 | 25 | |
735 | 735 | return func |
736 | 736 | |
737 | 737 | return decorator |
738 | ||
739 | ||
740 | TV = TypeVar("TV") | |
741 | ||
742 | ||
743 | def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]: | |
744 | """A test decorator which will skip the decorated test unless a condition is set | |
745 | ||
746 | For example: | |
747 | ||
748 | class MyTestCase(TestCase): | |
749 | @skip_unless(HAS_FOO, "Cannot test without foo") | |
750 | def test_foo(self): | |
751 | ... | |
752 | ||
753 | Args: | |
754 | condition: If true, the test will be run; otherwise it will be skipped | |
755 | reason: the reason to give for skipping the test | |
756 | """ | |
757 | ||
758 | def decorator(f: TV) -> TV: | |
759 | if not condition: | |
760 | f.skip = reason # type: ignore | |
761 | return f | |
762 | ||
763 | return decorator |
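
`skip_unless` leans on Twisted trial's convention that a test method (or class) carrying a `skip` attribute is skipped with that string as the reason, which is the same mechanism as the class-level lxml guards elsewhere in this changeset. A sketch of typical use with an optional dependency (`HAS_LXML` is an illustrative flag, not defined in this diff; assumes a trial-based `TestCase` like the ones in this file):

    try:
        import lxml  # noqa: F401

        HAS_LXML = True
    except ImportError:
        HAS_LXML = False


    class ExampleTestCase(TestCase):
        @skip_unless(HAS_LXML, "url preview feature requires lxml")
        def test_needs_lxml(self):
            ...
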
24 | 24 | class DeferredCacheTestCase(TestCase): |
25 | 25 | def test_empty(self): |
26 | 26 | cache = DeferredCache("test") |
27 | failed = False | |
28 | try: | |
27 | with self.assertRaises(KeyError): | |
29 | 28 | cache.get("foo") |
30 | except KeyError: | |
31 | failed = True | |
32 | ||
33 | self.assertTrue(failed) | |
34 | 29 | |
35 | 30 | def test_hit(self): |
36 | 31 | cache = DeferredCache("test") |
154 | 149 | cache.prefill(("foo",), 123) |
155 | 150 | cache.invalidate(("foo",)) |
156 | 151 | |
157 | failed = False | |
158 | try: | |
152 | with self.assertRaises(KeyError): | |
159 | 153 | cache.get(("foo",)) |
160 | except KeyError: | |
161 | failed = True | |
162 | ||
163 | self.assertTrue(failed) | |
164 | 154 | |
165 | 155 | def test_invalidate_all(self): |
166 | 156 | cache = DeferredCache("testcache") |
214 | 204 | cache.prefill(2, "two") |
215 | 205 | cache.prefill(3, "three") # 1 will be evicted |
216 | 206 | |
217 | failed = False | |
218 | try: | |
207 | with self.assertRaises(KeyError): | |
219 | 208 | cache.get(1) |
220 | except KeyError: | |
221 | failed = True | |
222 | ||
223 | self.assertTrue(failed) | |
224 | 209 | |
225 | 210 | cache.get(2) |
226 | 211 | cache.get(3) |
238 | 223 | |
239 | 224 | cache.prefill(3, "three") |
240 | 225 | |
241 | failed = False | |
242 | try: | |
226 | with self.assertRaises(KeyError): | |
243 | 227 | cache.get(2) |
244 | except KeyError: | |
245 | failed = True | |
246 | ||
247 | self.assertTrue(failed) | |
248 | 228 | |
249 | 229 | cache.get(1) |
250 | 230 | cache.get(3) |
231 | ||
232 | def test_eviction_iterable(self): | |
233 | cache = DeferredCache( | |
234 | "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True, | |
235 | ) | |
236 | ||
237 | cache.prefill(1, ["one", "two"]) | |
238 | cache.prefill(2, ["three"]) | |
239 | ||
240 | # Now access 1 again, thus causing 2 to be least-recently used | |
241 | cache.get(1) | |
242 | ||
243 | # Now add an item to the cache, which evicts 2. | |
244 | cache.prefill(3, ["four"]) | |
245 | with self.assertRaises(KeyError): | |
246 | cache.get(2) | |
247 | ||
248 | # Ensure 1 & 3 are in the cache. | |
249 | cache.get(1) | |
250 | cache.get(3) | |
251 | ||
252 | # Now access 1 again, thus causing 3 to be least-recently used | |
253 | cache.get(1) | |
254 | ||
255 | # Now add an item with multiple elements to the cache | |
256 | cache.prefill(4, ["five", "six"]) | |
257 | ||
258 | # Both 1 and 3 are evicted since there's too many elements. | |
259 | with self.assertRaises(KeyError): | |
260 | cache.get(1) | |
261 | with self.assertRaises(KeyError): | |
262 | cache.get(3) | |
263 | ||
264 | # Now add another item to fill the cache again. | |
265 | cache.prefill(5, ["seven"]) | |
266 | ||
267 | # Now access 4, thus causing 5 to be least-recently used | |
268 | cache.get(4) | |
269 | ||
270 | # Add an empty item. | |
271 | cache.prefill(6, []) | |
272 | ||
273 | # 5 gets evicted and replaced since an empty element counts as an item. | |
274 | with self.assertRaises(KeyError): | |
275 | cache.get(5) | |
276 | cache.get(4) | |
277 | cache.get(6) |
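
The `iterable=True` behaviour exercised above means each entry is charged by the number of elements in its value rather than counting one per key, so `max_entries` caps the total element count and inserting one value can evict several older entries at once. Restating the accounting (illustrative values only):

    cache = DeferredCache(
        "example", max_entries=3, apply_cache_factor_from_config=False, iterable=True
    )
    cache.prefill("a", ["x", "y"])  # cache size is now 2
    cache.prefill("b", ["z"])       # size 3: exactly at capacity
    cache.prefill("c", ["w"])       # would be size 4, so LRU entry "a" is evicted
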
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | from synapse.util.iterutils import chunk_seq | |
14 | from typing import Dict, List | |
15 | ||
16 | from synapse.util.iterutils import chunk_seq, sorted_topologically | |
15 | 17 | |
16 | 18 | from tests.unittest import TestCase |
17 | 19 | |
44 | 46 | self.assertEqual( |
45 | 47 | list(parts), [], |
46 | 48 | ) |
49 | ||
50 | ||
51 | class SortTopologically(TestCase): | |
52 | def test_empty(self): | |
53 | "Test that an empty graph works correctly" | |
54 | ||
55 | graph = {} # type: Dict[int, List[int]] | |
56 | self.assertEqual(list(sorted_topologically([], graph)), []) | |
57 | ||
58 | def test_handle_empty_graph(self): | |
59 | "Test that a graph where a node doesn't have an entry is treated as empty" | |
60 | ||
61 | graph = {} # type: Dict[int, List[int]] | |
62 | ||
63 | # For nodes with no graph entry the output is simply sorted. | |
64 | self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) | |
65 | ||
66 | def test_disconnected(self): | |
67 | "Test that a graph with no edges work" | |
68 | ||
69 | graph = {1: [], 2: []} # type: Dict[int, List[int]] | |
70 | ||
71 | # For disconnected nodes the output is simply sorted. | |
72 | self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) | |
73 | ||
74 | def test_linear(self): | |
75 | "Test that a simple `4 -> 3 -> 2 -> 1` graph works" | |
76 | ||
77 | graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] | |
78 | ||
79 | self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) | |
80 | ||
81 | def test_subset(self): | |
82 | "Test that only sorting a subset of the graph works" | |
83 | graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] | |
84 | ||
85 | self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) | |
86 | ||
87 | def test_fork(self): | |
88 | "Test that a forked graph works" | |
89 | graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] | |
90 | ||
91 | # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should | |
92 | # always get the same one. | |
93 | self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) | |
94 | ||
95 | def test_duplicates(self): | |
96 | "Test that a graph with duplicate edges work" | |
97 | graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] | |
98 | ||
99 | self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) | |
100 | ||
101 | def test_multiple_paths(self): | |
102 | "Test that a graph with multiple paths between two nodes work" | |
103 | graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] | |
104 | ||
105 | self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) |
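
The implementation of `sorted_topologically` is not part of this changeset, but the tests pin down its contract: only the requested nodes are emitted, nodes absent from the graph (or whose edges point outside the requested set) are treated as having no dependencies, duplicate edges are tolerated, and ties are broken by sorting so the output is deterministic. A sketch satisfying those tests, assuming a Kahn-style traversal with a heap (not Synapse's actual implementation):

    import heapq
    from typing import Dict, Iterable, Iterator, List, TypeVar

    T = TypeVar("T")


    def sorted_topologically_sketch(
        nodes: Iterable[T], graph: Dict[T, List[T]]
    ) -> Iterator[T]:
        # Count edges between requested nodes only; edges to nodes outside
        # `nodes`, or to nodes missing from `graph`, contribute nothing.
        node_set = set(nodes)
        degree = {node: 0 for node in node_set}
        reverse = {}  # type: Dict[T, List[T]]
        for node in node_set:
            for edge in set(graph.get(node, [])):  # set() drops duplicate edges
                if edge in node_set:
                    degree[node] += 1
                    reverse.setdefault(edge, []).append(node)

        # Popping from a heap yields ready nodes smallest-first, giving the
        # deterministic ordering the fork/multiple-path tests rely on.
        heap = [node for node, count in degree.items() if count == 0]
        heapq.heapify(heap)
        while heap:
            node = heapq.heappop(heap)
            yield node
            for dependent in reverse.get(node, []):
                degree[dependent] -= 1
                if degree[dependent] == 0:
                    heapq.heappush(heap, dependent)
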
158 | 158 | "remote": {"per_second": 10000, "burst_count": 10000}, |
159 | 159 | }, |
160 | 160 | "saml2_enabled": False, |
161 | "public_baseurl": None, | |
162 | 161 | "default_identity_server": None, |
163 | 162 | "key_refresh_interval": 24 * 60 * 60 * 1000, |
164 | 163 | "old_signing_keys": {}, |
1 | 1 | envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort |
2 | 2 | |
3 | 3 | [base] |
4 | extras = test | |
5 | 4 | deps = |
6 | 5 | python-subunit |
7 | 6 | junitxml |
24 | 23 | # install the "enum34" dependency of cryptography. |
25 | 24 | pip>=10 |
26 | 25 | |
26 | # directories/files we run the linters on | |
27 | lint_targets = | |
28 | setup.py | |
29 | synapse | |
30 | tests | |
31 | scripts | |
32 | scripts-dev | |
33 | stubs | |
34 | contrib | |
35 | synctl | |
36 | synmark | |
37 | .buildkite | |
38 | docker | |
39 | ||
40 | # default settings for all tox environments | |
27 | 41 | [testenv] |
28 | 42 | deps = |
29 | 43 | {[base]deps} |
30 | extras = all, test | |
44 | extras = | |
45 | # install the optional dependencies for tox environments without | |
46 | # '-noextras' in their name | |
47 | !noextras: all | |
48 | test | |
31 | 49 | |
32 | 50 | setenv = |
33 | 51 | # use a postgres db for tox environments with "-postgres" in the name |
84 | 102 | [testenv:py35-old] |
85 | 103 | skip_install=True |
86 | 104 | deps = |
105 | # Ensure a version of setuptools that supports Python 3.5 is installed. | |
106 | setuptools < 51.0.0 | |
107 | ||
87 | 108 | # Old automat version for Twisted |
88 | 109 | Automat == 0.3.0 |
89 | 110 | |
95 | 116 | # Make all greater-thans equals so we test the oldest version of our direct |
96 | 117 | # dependencies, but make the pyopenssl 17.0, which can work against an |
97 | 118 | # OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis). |
98 | /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install' | |
119 | /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install' | |
99 | 120 | |
100 | 121 | # Install Synapse itself. This won't update any libraries. |
101 | 122 | pip install -e ".[test]" |
125 | 146 | [testenv:check_codestyle] |
126 | 147 | extras = lint |
127 | 148 | commands = |
128 | python -m black --check --diff . | |
129 | /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}" | |
149 | python -m black --check --diff {[base]lint_targets} | |
150 | flake8 {[base]lint_targets} {env:PEP8SUFFIX:} | |
130 | 151 | {toxinidir}/scripts-dev/config-lint.sh |
131 | 152 | |
132 | 153 | [testenv:check_isort] |
133 | 154 | extras = lint |
134 | commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts" | |
155 | commands = isort -c --df --sp setup.cfg {[base]lint_targets} | |
135 | 156 | |
136 | 157 | [testenv:check-newsfragment] |
137 | 158 | skip_install = True |