matrix-synapse / 70e3daa
New upstream version 1.26.0 Andrej Shadura 3 years ago
185 changed file(s) with 8798 addition(s) and 2802 deletion(s).
1414 # limitations under the License.
1515
1616 import logging
17
1718 from synapse.storage.engines import create_engine
1819
1920 logger = logging.getLogger("create_postgres_db")
1111 _trial_temp/
1212 _trial_temp*/
1313 /out
14 .DS_Store
1415
1516 # stuff that is likely to exist when you run a server locally
1617 /*.db
1718 /*.log
19 /*.log.*
1820 /*.log.config
1921 /*.pid
2022 /.python-version
0 Synapse 1.26.0 (2021-01-27)
1 ===========================
2
3 This release brings a new schema version for Synapse and rolling back to a previous
4 version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
5 on these changes and for general upgrade guidance.
6
7 No significant changes since 1.26.0rc2.
8
9
10 Synapse 1.26.0rc2 (2021-01-25)
11 ==============================
12
13 Bugfixes
14 --------
15
16 - Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
17 - Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
18
19
20 Internal Changes
21 ----------------
22
23 - Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
24 - Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
25
26
27 Synapse 1.26.0rc1 (2021-01-20)
28 ==============================
29
30 This release brings a new schema version for Synapse and rolling back to a previous
31 version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
32 on these changes and for general upgrade guidance.
33
34 Features
35 --------
36
37 - Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177))
38 - During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091))
39 - Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159))
40 - Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024))
41 - Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984))
42 - Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086))
43 - Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932))
44 - Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948))
45 - Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130))
46 - Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068))
47 - Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092))
48 - Add experimental support for moving receipts and account data persistence off the master process. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166))
49
50
51 Bugfixes
52 --------
53
54 - Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023))
55 - Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028))
56 - Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051))
57 - Fix bug where we didn't correctly record CPU time spent in `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053))
58 - Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054))
59 - Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059))
60 - Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070))
61 - Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071))
62 - Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116))
63 - Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117))
64 - Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128))
65 - Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108))
66 - Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145))
67 - Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161))
68
69
70 Improved Documentation
71 ----------------------
72
73 - Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997))
74 - Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035))
75 - Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040))
76 - Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057))
77 - Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151))
78
79
80 Deprecations and Removals
81 -------------------------
82
83 - Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039))
84
85
86 Internal Changes
87 ----------------
88
89 - Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124))
90 - Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939))
91 - Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016))
92 - Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018))
93 - Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025))
94 - Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030))
95 - Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031))
96 - Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033))
97 - Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038))
98 - Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041))
99 - Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055))
100 - Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058))
101 - Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063))
102 - Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069))
103 - Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080))
104 - Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093))
105 - Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098))
106 - Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106))
107 - Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112))
108 - Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125))
109 - Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144))
110 - Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146))
111 - Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157))
112
113
0114 Synapse 1.25.0 (2021-01-13)
1115 ===========================
2116
189189
190190 ```sh
191191 brew install openssl@1.1
192 export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
192 export LDFLAGS="-L/usr/local/opt/openssl/lib"
193 export CPPFLAGS="-I/usr/local/opt/openssl/include"
193194 ```
194195
195196 ##### OpenSUSE
256257
257258 #### Docker images and Ansible playbooks
258259
259 There is an offical synapse image available at
260 There is an official synapse image available at
260261 <https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
261262 the docker-compose file available at [contrib/docker](contrib/docker). Further
262263 information on this including configuration options is available in the README
242242 Synapse Development
243243 ===================
244244
245 Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
245 Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
246246
247247 Before setting up a development environment for synapse, make sure you have the
248248 system dependencies (such as the python header files) installed - see
278278 Ran 1337 tests in 716.064s
279279
280280 PASSED (skips=15, successes=1322)
281
282 We recommend using the demo, which starts 3 federated instances running on ports `8080` - `8082`.
283
284 ./demo/start.sh
285
286 (to stop, you can use `./demo/stop.sh`)
287
288 If you just want to start a single instance of the app and run it directly::
289
290 # Create the homeserver.yaml config once
291 python -m synapse.app.homeserver \
292 --server-name my.domain.name \
293 --config-path homeserver.yaml \
294 --generate-config \
295 --report-stats=[yes|no]
296
297 # Start the app
298 python -m synapse.app.homeserver --config-path homeserver.yaml
299
300
301
281302
282303 Running the Integration Tests
283304 =============================
8383 # replace `1.3.0` and `stretch` accordingly:
8484 wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
8585 dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
86
87 Upgrading to v1.26.0
88 ====================
89
90 Rolling back to v1.25.0 after a failed upgrade
91 ----------------------------------------------
92
93 v1.26.0 includes a lot of large changes. If something problematic occurs, you
94 may want to roll back to a previous version of Synapse. Because v1.26.0 also
95 includes a new database schema version, reverting that version is also required
96 alongside the generic rollback instructions mentioned above. In short, to roll
97 back to v1.25.0 you need to:
98
99 1. Stop the server
100 2. Decrease the schema version in the database:
101
102 .. code:: sql
103
104 UPDATE schema_version SET version = 58;
105
106 3. Delete the ignored users & chain cover data:
107
108 .. code:: sql
109
110 DROP TABLE IF EXISTS ignored_users;
111 UPDATE rooms SET has_auth_chain_index = false;
112
113 For PostgreSQL run:
114
115 .. code:: sql
116
117 TRUNCATE event_auth_chain_links;
118 TRUNCATE event_auth_chains;
119
120 For SQLite run:
121
122 .. code:: sql
123
124 DELETE FROM event_auth_chain_links;
125 DELETE FROM event_auth_chains;
126
127 4. Mark the deltas as not run (so they will re-run on upgrade).
128
129 .. code:: sql
130
131 DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/01ignored_user.py';
132 DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/06chain_cover_index.sql';
133
134 5. Downgrade Synapse by following the instructions for your installation method
135 in the "Rolling back to older versions" section above.
86136
87137 Upgrading to v1.25.0
88138 ====================
0 matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
1
2 * Remove dependency on `python3-distutils`.
3
4 -- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000
5
06 matrix-synapse-py3 (1.25.0) stable; urgency=medium
17
28 [ Dan Callahan ]
3030 Depends:
3131 adduser,
3232 debconf,
33 python3-distutils|libpython3-stdlib (<< 3.6),
3433 ${misc:Depends},
3534 ${shlibs:Depends},
3635 ${synapse:pydepends},
+0
-59
demo/webserver.py
0 import argparse
1 import BaseHTTPServer
2 import os
3 import SimpleHTTPServer
4 import cgi, logging
5
6 from daemonize import Daemonize
7
8
9 class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
10 UPLOAD_PATH = "upload"
11
12 """
13 Accept all post request as file upload
14 """
15
16 def do_POST(self):
17
18 path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
19 length = self.headers["content-length"]
20 data = self.rfile.read(int(length))
21
22 with open(path, "wb") as fh:
23 fh.write(data)
24
25 self.send_response(200)
26 self.send_header("Content-Type", "application/json")
27 self.end_headers()
28
29 # Return the absolute path of the uploaded file
30 self.wfile.write('{"url":"/%s"}' % path)
31
32
33 def setup():
34 parser = argparse.ArgumentParser()
35 parser.add_argument("directory")
36 parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
37 parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
38 args = parser.parse_args()
39
40 # Get absolute path to directory to serve, as daemonize changes to '/'
41 os.chdir(args.directory)
42 dr = os.getcwd()
43
44 httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
45
46 def run():
47 os.chdir(dr)
48 httpd.serve_forever()
49
50 daemon = Daemonize(
51 app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
52 )
53
54 daemon.start()
55
56
57 if __name__ == "__main__":
58 setup()
197197 key_refresh_interval: "1d" # 1 Day.
198198
199199 # The trusted servers to download signing keys from.
200 perspectives:
201 servers:
202 "matrix.org":
203 verify_keys:
204 "ed25519:auto":
205 key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
200 trusted_key_servers:
201 - server_name: matrix.org
202 verify_keys:
203 "ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
206204
207205 password_config:
208206 enabled: true
33 * [Quarantining media by ID](#quarantining-media-by-id)
44 * [Quarantining media in a room](#quarantining-media-in-a-room)
55 * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
6 * [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
67 - [Delete local media](#delete-local-media)
78 * [Delete a specific local media](#delete-a-specific-local-media)
89 * [Delete local media by date or size](#delete-local-media-by-date-or-size)
121122 The following fields are returned in the JSON response body:
122123
123124 * `num_quarantined`: integer - The number of media items successfully quarantined
125
126 ## Protecting media from being quarantined
127
128 This API protects a single piece of local media from being quarantined using the
129 above APIs. This is useful for sticker packs and other shared media which you do
130 not want to get quarantined, especially when
131 [quarantining media in a room](#quarantining-media-in-a-room).
132
133 Request:
134
135 ```
136 POST /_synapse/admin/v1/media/protect/<media_id>
137
138 {}
139 ```
140
141 Where `media_id` is in the form of `abcdefg12345...`.
142
143 Response:
144
145 ```json
146 {}
147 ```
124148
125149 # Delete local media
126150 This API deletes the *local* media from the disk of your own server.
9797
9898 - ``deactivated``, optional. If unspecified, deactivation state will be left
9999 unchanged on existing accounts and set to ``false`` for new accounts.
100 A user cannot be erased by deactivating with this API. For details on deactivating users see
101 `Deactivate Account <#deactivate-account>`_.
100102
101103 If the user already exists then optional parameters default to the current value.
102104
247249 The erase parameter is optional and defaults to ``false``.
248250 An empty body may be passed for backwards compatibility.
249251
252 The following actions are performed when deactivating a user:
253
254 - Try to unbind 3PIDs from the identity server
255 - Remove all 3PIDs from the homeserver
256 - Delete all devices and E2EE keys
257 - Delete all access tokens
258 - Delete the password hash
259 - Remove the user from all rooms they are a member of
260 - Remove the user from the user directory
261 - Reject all pending invites
262 - Remove all account validity information related to the user
263
264 The following additional actions are performed during deactivation if ``erase``
265 is set to ``true``:
266
267 - Remove the user's display name
268 - Remove the user's avatar URL
269 - Mark the user as erased
270
250271
251272 Reset password
252273 ==============
335356 ],
336357 "total": 2
337358 }
359
360 The server returns the list of rooms of which both the user and the server
361 are members. If the user is local, all of the rooms of which the user is
362 a member are returned.
338363
339364 **Parameters**
340365
0 digraph auth {
1 nodesep=0.5;
2 rankdir="RL";
3
4 C [label="Create (1,1)"];
5
6 BJ [label="Bob's Join (2,1)", color=red];
7 BJ2 [label="Bob's Join (2,2)", color=red];
8 BJ2 -> BJ [color=red, dir=none];
9
10 subgraph cluster_foo {
11 A1 [label="Alice's invite (4,1)", color=blue];
12 A2 [label="Alice's Join (4,2)", color=blue];
13 A3 [label="Alice's Join (4,3)", color=blue];
14 A3 -> A2 -> A1 [color=blue, dir=none];
15 color=none;
16 }
17
18 PL1 [label="Power Level (3,1)", color=darkgreen];
19 PL2 [label="Power Level (3,2)", color=darkgreen];
20 PL2 -> PL1 [color=darkgreen, dir=none];
21
22 {rank = same; C; BJ; PL1; A1;}
23
24 A1 -> C [color=grey];
25 A1 -> BJ [color=grey];
26 PL1 -> C [color=grey];
27 BJ2 -> PL1 [penwidth=2];
28
29 A3 -> PL2 [penwidth=2];
30 A1 -> PL1 -> BJ -> C [penwidth=2];
31 }
0 # Auth Chain Difference Algorithm
1
2 The auth chain difference algorithm is used by V2 state resolution, where a
3 naive implementation can be a significant source of CPU and DB usage.
4
5 ### Definitions
6
7 A *state set* is a set of state events; e.g. the input of a state resolution
8 algorithm is a collection of state sets.
9
10 The *auth chain* of a set of events is the set of all of those events' auth events and *their*
11 auth events, recursively (i.e. the events reachable by walking the graph induced
12 by the events' auth event links).
13
14 The *auth chain difference* of a collection of state sets is the union minus the
15 intersection of the sets of auth chains corresponding to the state sets, i.e. an
16 event is in the auth chain difference if it is reachable by walking the auth
17 event graph from at least one of the state sets but not from *all* of the state
18 sets.
19
20 ## Breadth First Walk Algorithm
21
22 A way of calculating the auth chain difference without calculating the full auth
23 chains for each state set is to do a parallel breadth first walk (ordered by
24 depth) of each state set's auth chain. By tracking which events are reachable
25 from each state set we can finish early if every pending event is reachable from
26 every state set.
27
28 This can work well for state sets that have a small auth chain difference, but
29 can be very inefficient for larger differences. However, this algorithm is still
30 used if we don't have a chain cover index for the room (e.g. because we're in
31 the process of indexing it).
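
As a loose Python sketch of that walk (not Synapse's implementation; the
`auth_events` and `depths` lookups below are stand-ins for database queries):

```python
from typing import Dict, List, Set

# Stand-in lookups for the sketch: each event's auth events and its depth.
auth_events: Dict[str, List[str]] = {}
depths: Dict[str, int] = {}


def auth_chain_difference_bfs(state_sets: List[Set[str]]) -> Set[str]:
    """Events in at least one state set's auth chain, but not in all of them."""
    num_sets = len(state_sets)

    # For each auth chain event discovered so far, which state sets reach it.
    chains: Dict[str, Set[int]] = {}
    # Events still to expand, with the state sets whose walk has reached them.
    pending: Dict[str, Set[int]] = {}

    for i, state_set in enumerate(state_sets):
        for event_id in state_set:
            pending.setdefault(event_id, set()).add(i)

    while pending:
        # Finish early: if every pending event is already reachable from all
        # state sets, nothing new can enter the difference.
        if all(len(sets) == num_sets for sets in pending.values()):
            break

        # Walk in depth order, expanding the deepest pending event next.
        event_id = max(pending, key=lambda e: depths.get(e, 0))
        sets = pending.pop(event_id)

        for auth_id in auth_events.get(event_id, []):
            seen = chains.setdefault(auth_id, set())
            if not sets <= seen:
                seen |= sets
                pending.setdefault(auth_id, set()).update(seen)

    return {e for e, sets in chains.items() if len(sets) < num_sets}
```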
32
33 ## Chain Cover Index
34
35 Synapse computes auth chain differences by pre-computing a "chain cover" index
36 for the auth chain in a room, allowing efficient reachability queries like "is
37 event A in the auth chain of event B". This is done by assigning every event a
38 *chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
39 between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
40 is in the auth chain of `B`) if and only if either:
41
42 1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
43 sequence number; or
44 2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
45 `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
46
47 There are actually two potential implementations, one where we store links from
48 each chain to every other reachable chain (the transitive closure of the links
49 graph), and one where we remove redundant links (the transitive reduction of the
50 links graph), e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
51 would not be stored. Synapse uses the former implementation so that it doesn't
52 need to recurse to test reachability between chains.
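
As a concrete reading of the two rules above, here is a small sketch (an assumed
in-memory layout, not Synapse's storage schema) of the reachability check,
together with one pair of events from the example in the following section:

```python
from typing import Dict, List, Tuple

# Assumed layout for the sketch: each event's (chain ID, sequence number), and
# for each ordered pair of chains the known links, stored as
# (start_seq_no, end_seq_no) pairs.
event_to_chain: Dict[str, Tuple[int, int]] = {}
links: Dict[Tuple[int, int], List[Tuple[int, int]]] = {}


def is_in_auth_chain(a: str, b: str) -> bool:
    """Is event `a` in the auth chain of event `b`?"""
    a_chain, a_seq = event_to_chain[a]
    b_chain, b_seq = event_to_chain[b]

    if a_chain == b_chain:
        # Rule 1: same chain, and `a` has the smaller sequence number.
        return a_seq < b_seq

    # Rule 2: some link from `b`'s chain to `a`'s chain that `b` can use
    # (start_seq_no <= b.seq_no) and that reaches at least as far as `a`
    # (a.seq_no <= end_seq_no).
    return any(
        start_seq <= b_seq and a_seq <= end_seq
        for start_seq, end_seq in links.get((b_chain, a_chain), [])
    )


# From the example below: Bob's second join (2,2) links to the first power
# level (3,1), so the power level is in the join's auth chain.
event_to_chain.update({"power_level_1": (3, 1), "bobs_join_2": (2, 2)})
links[(2, 3)] = [(2, 1)]
assert is_in_auth_chain("power_level_1", "bobs_join_2")
```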
53
54 ### Example
55
56 An example auth graph would look like the following, where chains have been
57 formed based on type/state_key, and are denoted by colour and labelled with
58 `(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
59 are those that would be removed in the second implementation described above).
60
61 ![Example](auth_chain_diff.dot.png)
62
63 Note that we don't include all links between events and their auth events, as
64 most of those links would be redundant. For example, all events point to the
65 create event, but each chain only needs the one link from its base to the
66 create event.
67
68 ## Using the Index
69
70 This index can be used to calculate the auth chain difference of the state sets
71 by looking at the chain ID and sequence numbers reachable from each state set:
72
73 1. For every state set lookup the chain ID/sequence numbers of each state event
74 2. Use the index to find all chains and the maximum sequence number reachable
75 from each state set.
76 3. The auth chain difference is then all events in each chain that have sequence
77 numbers between the maximum sequence number reachable from *any* state set and
78 the minimum reachable by *all* state sets (if any).
79
80 Note that step 2 is effectively calculating the auth chain for each state set
81 (in terms of chain IDs and sequence numbers), and step 3 is calculating the
82 difference between the union and intersection of the auth chains.
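
A minimal sketch of step 3 (a hypothetical helper that takes the per-state-set
reachability maps produced by steps 1 and 2 as input, checked against the
figures from the worked example that follows):

```python
from typing import Dict, List, Set, Tuple


def auth_chain_difference_from_index(
    reachable_per_state_set: List[Dict[int, int]],
) -> Set[Tuple[int, int]]:
    """Given, for each state set, a map of chain ID to the maximum sequence
    number reachable from that state set, return the auth chain difference as
    (chain ID, sequence number) pairs."""
    all_chains: Set[int] = set()
    for reach in reachable_per_state_set:
        all_chains.update(reach.keys())

    difference: Set[Tuple[int, int]] = set()
    for chain_id in all_chains:
        maxima = [reach.get(chain_id, 0) for reach in reachable_per_state_set]
        max_any = max(maxima)  # furthest *any* state set reaches into this chain
        min_all = min(maxima)  # furthest *every* state set reaches (0 if one can't)
        # Everything strictly above the common bound, up to the furthest reach,
        # is in some but not all of the auth chains.
        difference.update((chain_id, seq) for seq in range(min_all + 1, max_any + 1))

    return difference


# Worked example below: S1 reaches (1,1), (2,2), (3,1), (4,1);
# S2 reaches (1,1), (2,1), (3,2), (4,3).
s1 = {1: 1, 2: 2, 3: 1, 4: 1}
s2 = {1: 1, 2: 1, 3: 2, 4: 3}
assert auth_chain_difference_from_index([s1, s2]) == {(2, 2), (3, 2), (4, 2), (4, 3)}
```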
83
84 ### Worked Example
85
86 For example, given the above graph, we can calculate the difference between
87 state sets consisting of:
88
89 1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
90 2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
91
92 Using the index we see that the following auth chains are reachable from each
93 state set:
94
95 1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
96 2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
97
98 And so, for each chain, the ranges that are in the auth chain difference are:
99 1. Chain 1: None (since everything can reach the create event).
100 2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
101 sets and the maximum reachable is `2` (corresponding to Bob's second join).
102 3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
103 level).
104 4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
105
106 So the final result is: Bob's second join `(2,2)`, the second power level
107 `(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
4141 * For other installation mechanisms, see the documentation provided by the
4242 maintainer.
4343
44 To enable the OpenID integration, you should then add an `oidc_config` section
45 to your configuration file (or uncomment the `enabled: true` line in the
46 existing section). See [sample_config.yaml](./sample_config.yaml) for some
47 sample settings, as well as the text below for example configurations for
48 specific providers.
44 To enable the OpenID integration, you should then add a section to the `oidc_providers`
45 setting in your configuration file (or uncomment one of the existing examples).
46 See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
47 the text below for example configurations for specific providers.
4948
5049 ## Sample configs
5150
6160 Edit your Synapse config file and change the `oidc_config` section:
6261
6362 ```yaml
64 oidc_config:
65 enabled: true
66 issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
67 client_id: "<client id>"
68 client_secret: "<client secret>"
69 scopes: ["openid", "profile"]
70 authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
71 token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
72 userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
73
74 user_mapping_provider:
75 config:
76 localpart_template: "{{ user.preferred_username.split('@')[0] }}"
77 display_name_template: "{{ user.name }}"
63 oidc_providers:
64 - idp_id: microsoft
65 idp_name: Microsoft
66 issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
67 client_id: "<client id>"
68 client_secret: "<client secret>"
69 scopes: ["openid", "profile"]
70 authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
71 token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
72 userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
73
74 user_mapping_provider:
75 config:
76 localpart_template: "{{ user.preferred_username.split('@')[0] }}"
77 display_name_template: "{{ user.name }}"
7878 ```
7979
8080 ### [Dex][dex-idp]
102102 Synapse config:
103103
104104 ```yaml
105 oidc_config:
106 enabled: true
107 skip_verification: true # This is needed as Dex is served on an insecure endpoint
108 issuer: "http://127.0.0.1:5556/dex"
109 client_id: "synapse"
110 client_secret: "secret"
111 scopes: ["openid", "profile"]
112 user_mapping_provider:
113 config:
114 localpart_template: "{{ user.name }}"
115 display_name_template: "{{ user.name|capitalize }}"
105 oidc_providers:
106 - idp_id: dex
107 idp_name: "My Dex server"
108 skip_verification: true # This is needed as Dex is served on an insecure endpoint
109 issuer: "http://127.0.0.1:5556/dex"
110 client_id: "synapse"
111 client_secret: "secret"
112 scopes: ["openid", "profile"]
113 user_mapping_provider:
114 config:
115 localpart_template: "{{ user.name }}"
116 display_name_template: "{{ user.name|capitalize }}"
116117 ```
117118 ### [Keycloak][keycloak-idp]
118119
151152 8. Copy Secret
152153
153154 ```yaml
154 oidc_config:
155 enabled: true
156 issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
157 client_id: "synapse"
158 client_secret: "copy secret generated from above"
159 scopes: ["openid", "profile"]
155 oidc_providers:
156 - idp_id: keycloak
157 idp_name: "My KeyCloak server"
158 issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
159 client_id: "synapse"
160 client_secret: "copy secret generated from above"
161 scopes: ["openid", "profile"]
162 user_mapping_provider:
163 config:
164 localpart_template: "{{ user.preferred_username }}"
165 display_name_template: "{{ user.name }}"
160166 ```
161167 ### [Auth0][auth0]
162168
186192 Synapse config:
187193
188194 ```yaml
189 oidc_config:
190 enabled: true
191 issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
192 client_id: "your-client-id" # TO BE FILLED
193 client_secret: "your-client-secret" # TO BE FILLED
194 scopes: ["openid", "profile"]
195 user_mapping_provider:
196 config:
197 localpart_template: "{{ user.preferred_username }}"
198 display_name_template: "{{ user.name }}"
195 oidc_providers:
196 - idp_id: auth0
197 idp_name: Auth0
198 issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
199 client_id: "your-client-id" # TO BE FILLED
200 client_secret: "your-client-secret" # TO BE FILLED
201 scopes: ["openid", "profile"]
202 user_mapping_provider:
203 config:
204 localpart_template: "{{ user.preferred_username }}"
205 display_name_template: "{{ user.name }}"
199206 ```
200207
201208 ### GitHub
214221 Synapse config:
215222
216223 ```yaml
217 oidc_config:
218 enabled: true
219 discover: false
220 issuer: "https://github.com/"
221 client_id: "your-client-id" # TO BE FILLED
222 client_secret: "your-client-secret" # TO BE FILLED
223 authorization_endpoint: "https://github.com/login/oauth/authorize"
224 token_endpoint: "https://github.com/login/oauth/access_token"
225 userinfo_endpoint: "https://api.github.com/user"
226 scopes: ["read:user"]
227 user_mapping_provider:
228 config:
229 subject_claim: "id"
230 localpart_template: "{{ user.login }}"
231 display_name_template: "{{ user.name }}"
224 oidc_providers:
225 - idp_id: github
226 idp_name: Github
227 discover: false
228 issuer: "https://github.com/"
229 client_id: "your-client-id" # TO BE FILLED
230 client_secret: "your-client-secret" # TO BE FILLED
231 authorization_endpoint: "https://github.com/login/oauth/authorize"
232 token_endpoint: "https://github.com/login/oauth/access_token"
233 userinfo_endpoint: "https://api.github.com/user"
234 scopes: ["read:user"]
235 user_mapping_provider:
236 config:
237 subject_claim: "id"
238 localpart_template: "{{ user.login }}"
239 display_name_template: "{{ user.name }}"
232240 ```
233241
234242 ### [Google][google-idp]
238246 2. add an "OAuth Client ID" for a Web Application under "Credentials".
239247 3. Copy the Client ID and Client Secret, and add the following to your synapse config:
240248 ```yaml
241 oidc_config:
242 enabled: true
243 issuer: "https://accounts.google.com/"
244 client_id: "your-client-id" # TO BE FILLED
245 client_secret: "your-client-secret" # TO BE FILLED
246 scopes: ["openid", "profile"]
247 user_mapping_provider:
248 config:
249 localpart_template: "{{ user.given_name|lower }}"
250 display_name_template: "{{ user.name }}"
249 oidc_providers:
250 - idp_id: google
251 idp_name: Google
252 issuer: "https://accounts.google.com/"
253 client_id: "your-client-id" # TO BE FILLED
254 client_secret: "your-client-secret" # TO BE FILLED
255 scopes: ["openid", "profile"]
256 user_mapping_provider:
257 config:
258 localpart_template: "{{ user.given_name|lower }}"
259 display_name_template: "{{ user.name }}"
251260 ```
252261 4. Back in the Google console, add this Authorized redirect URI: `[synapse
253262 public baseurl]/_synapse/oidc/callback`.
261270 Synapse config:
262271
263272 ```yaml
264 oidc_config:
265 enabled: true
266 issuer: "https://id.twitch.tv/oauth2/"
267 client_id: "your-client-id" # TO BE FILLED
268 client_secret: "your-client-secret" # TO BE FILLED
269 client_auth_method: "client_secret_post"
270 user_mapping_provider:
271 config:
272 localpart_template: "{{ user.preferred_username }}"
273 display_name_template: "{{ user.name }}"
273 oidc_providers:
274 - idp_id: twitch
275 idp_name: Twitch
276 issuer: "https://id.twitch.tv/oauth2/"
277 client_id: "your-client-id" # TO BE FILLED
278 client_secret: "your-client-secret" # TO BE FILLED
279 client_auth_method: "client_secret_post"
280 user_mapping_provider:
281 config:
282 localpart_template: "{{ user.preferred_username }}"
283 display_name_template: "{{ user.name }}"
274284 ```
275285
276286 ### GitLab
282292 Synapse config:
283293
284294 ```yaml
285 oidc_config:
286 enabled: true
287 issuer: "https://gitlab.com/"
288 client_id: "your-client-id" # TO BE FILLED
289 client_secret: "your-client-secret" # TO BE FILLED
290 client_auth_method: "client_secret_post"
291 scopes: ["openid", "read_user"]
292 user_profile_method: "userinfo_endpoint"
293 user_mapping_provider:
294 config:
295 localpart_template: '{{ user.nickname }}'
296 display_name_template: '{{ user.name }}'
297 ```
295 oidc_providers:
296 - idp_id: gitlab
297 idp_name: Gitlab
298 issuer: "https://gitlab.com/"
299 client_id: "your-client-id" # TO BE FILLED
300 client_secret: "your-client-secret" # TO BE FILLED
301 client_auth_method: "client_secret_post"
302 scopes: ["openid", "read_user"]
303 user_profile_method: "userinfo_endpoint"
304 user_mapping_provider:
305 config:
306 localpart_template: '{{ user.nickname }}'
307 display_name_template: '{{ user.name }}'
308 ```
1717 virtualenv](../INSTALL.md#installing-from-source), you can install
1818 the library with:
1919
20 ~/synapse/env/bin/pip install matrix-synapse[postgres]
20 ~/synapse/env/bin/pip install "matrix-synapse[postgres]"
2121
2222 (substituting the path to your virtualenv for `~/synapse/env`, if
2323 you used a different path). You will require the postgres
6666 #
6767 #web_client_location: https://riot.example.com/
6868
69 # The public-facing base URL that clients use to access this HS
70 # (not including _matrix/...). This is the same URL a user would
71 # enter into the 'custom HS URL' field on their client. If you
72 # use synapse with a reverse proxy, this should be the URL to reach
73 # synapse via the proxy.
69 # The public-facing base URL that clients use to access this Homeserver (not
70 # including _matrix/...). This is the same URL a user might enter into the
71 # 'Custom Homeserver URL' field on their client. If you use Synapse with a
72 # reverse proxy, this should be the URL to reach Synapse via the proxy.
73 # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
74 # 'listeners' below).
75 #
76 # If this is left unset, it defaults to 'https://<server_name>/'. (Note that
77 # that will not work unless you configure Synapse or a reverse-proxy to listen
78 # on port 443.)
7479 #
7580 #public_baseurl: https://example.com/
7681
11491154 # send an email to the account's email address with a renewal link. By
11501155 # default, no such emails are sent.
11511156 #
1152 # If you enable this setting, you will also need to fill out the 'email' and
1153 # 'public_baseurl' configuration sections.
1157 # If you enable this setting, you will also need to fill out the 'email'
1158 # configuration section. You should also check that 'public_baseurl' is set
1159 # correctly.
11541160 #
11551161 #renew_at: 1w
11561162
12411247 # The identity server which we suggest that clients should use when users log
12421248 # in on this server.
12431249 #
1244 # (By default, no suggestion is made, so it is left up to the client.
1245 # This setting is ignored unless public_baseurl is also set.)
1250 # (By default, no suggestion is made, so it is left up to the client.)
12461251 #
12471252 #default_identity_server: https://matrix.org
12481253
12661271 # Servers handling the these requests must answer the `/requestToken` endpoints defined
12671272 # by the Matrix Identity Service API specification:
12681273 # https://matrix.org/docs/spec/identity_service/latest
1269 #
1270 # If a delegate is specified, the config option public_baseurl must also be filled out.
12711274 #
12721275 account_threepid_delegates:
12731276 #email: https://example.com # Delegate email sending to example.com
17081711 #idp_entityid: 'https://our_idp/entityid'
17091712
17101713
1711 # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
1714 # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
1715 # and login.
1716 #
1717 # Options for each entry include:
1718 #
1719 # idp_id: a unique identifier for this identity provider. Used internally
1720 # by Synapse; should be a single word such as 'github'.
1721 #
1722 # Note that, if this is changed, users authenticating via that provider
1723 # will no longer be recognised as the same user!
1724 #
1725 # idp_name: A user-facing name for this identity provider, which is used to
1726 # offer the user a choice of login mechanisms.
1727 #
1728 # idp_icon: An optional icon for this identity provider, which is presented
1729 # by identity picker pages. If given, must be an MXC URI of the format
1730 # mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI
1731 # is to upload an image to an (unencrypted) room and then copy the "url"
1732 # from the source of the event.)
1733 #
1734 # discover: set to 'false' to disable the use of the OIDC discovery mechanism
1735 # to discover endpoints. Defaults to true.
1736 #
1737 # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
1738 # is enabled) to discover the provider's endpoints.
1739 #
1740 # client_id: Required. oauth2 client id to use.
1741 #
1742 # client_secret: Required. oauth2 client secret to use.
1743 #
1744 # client_auth_method: auth method to use when exchanging the token. Valid
1745 # values are 'client_secret_basic' (default), 'client_secret_post' and
1746 # 'none'.
1747 #
1748 # scopes: list of scopes to request. This should normally include the "openid"
1749 # scope. Defaults to ["openid"].
1750 #
1751 # authorization_endpoint: the oauth2 authorization endpoint. Required if
1752 # provider discovery is disabled.
1753 #
1754 # token_endpoint: the oauth2 token endpoint. Required if provider discovery is
1755 # disabled.
1756 #
1757 # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
1758 # disabled and the 'openid' scope is not requested.
1759 #
1760 # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
1761 # the 'openid' scope is used.
1762 #
1763 # skip_verification: set to 'true' to skip metadata verification. Use this if
1764 # you are connecting to a provider that is not OpenID Connect compliant.
1765 # Defaults to false. Avoid this in production.
1766 #
1767 # user_profile_method: Whether to fetch the user profile from the userinfo
1768 # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
1769 #
1770 # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
1771 # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
1772 # userinfo endpoint.
1773 #
1774 # allow_existing_users: set to 'true' to allow a user logging in via OIDC to
1775 # match a pre-existing account instead of failing. This could be used if
1776 # switching from password logins to OIDC. Defaults to false.
1777 #
1778 # user_mapping_provider: Configuration for how attributes returned from a OIDC
1779 # provider are mapped onto a matrix user. This setting has the following
1780 # sub-properties:
1781 #
1782 # module: The class name of a custom mapping module. Default is
1783 # 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
1784 # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
1785 # for information on implementing a custom mapping provider.
1786 #
1787 # config: Configuration for the mapping provider module. This section will
1788 # be passed as a Python dictionary to the user mapping provider
1789 # module's `parse_config` method.
1790 #
1791 # For the default provider, the following settings are available:
1792 #
1793 # sub: name of the claim containing a unique identifier for the
1794 # user. Defaults to 'sub', which OpenID Connect compliant
1795 # providers should provide.
1796 #
1797 # localpart_template: Jinja2 template for the localpart of the MXID.
1798 # If this is not set, the user will be prompted to choose their
1799 # own username.
1800 #
1801 # display_name_template: Jinja2 template for the display name to set
1802 # on first login. If unset, no displayname will be set.
1803 #
1804 # extra_attributes: a map of Jinja2 templates for extra attributes
1805 # to send back to the client during login.
1806 # Note that these are non-standard and clients will ignore them
1807 # without modifications.
1808 #
1809 # When rendering, the Jinja2 templates are given a 'user' variable,
1810 # which is set to the claims returned by the UserInfo Endpoint and/or
1811 # in the ID Token.
17121812 #
17131813 # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
1714 # for some example configurations.
1715 #
1716 oidc_config:
1717 # Uncomment the following to enable authorization against an OpenID Connect
1718 # server. Defaults to false.
1719 #
1720 #enabled: true
1721
1722 # Uncomment the following to disable use of the OIDC discovery mechanism to
1723 # discover endpoints. Defaults to true.
1724 #
1725 #discover: false
1726
1727 # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
1728 # discover the provider's endpoints.
1729 #
1730 # Required if 'enabled' is true.
1731 #
1732 #issuer: "https://accounts.example.com/"
1733
1734 # oauth2 client id to use.
1735 #
1736 # Required if 'enabled' is true.
1737 #
1738 #client_id: "provided-by-your-issuer"
1739
1740 # oauth2 client secret to use.
1741 #
1742 # Required if 'enabled' is true.
1743 #
1744 #client_secret: "provided-by-your-issuer"
1745
1746 # auth method to use when exchanging the token.
1747 # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
1748 # 'none'.
1749 #
1750 #client_auth_method: client_secret_post
1751
1752 # list of scopes to request. This should normally include the "openid" scope.
1753 # Defaults to ["openid"].
1754 #
1755 #scopes: ["openid", "profile"]
1756
1757 # the oauth2 authorization endpoint. Required if provider discovery is disabled.
1758 #
1759 #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
1760
1761 # the oauth2 token endpoint. Required if provider discovery is disabled.
1762 #
1763 #token_endpoint: "https://accounts.example.com/oauth2/token"
1764
1765 # the OIDC userinfo endpoint. Required if discovery is disabled and the
1766 # "openid" scope is not requested.
1767 #
1768 #userinfo_endpoint: "https://accounts.example.com/userinfo"
1769
1770 # URI where to fetch the JWKS. Required if discovery is disabled and the
1771 # "openid" scope is used.
1772 #
1773 #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
1774
1775 # Uncomment to skip metadata verification. Defaults to false.
1776 #
1777 # Use this if you are connecting to a provider that is not OpenID Connect
1778 # compliant.
1779 # Avoid this in production.
1780 #
1781 #skip_verification: true
1782
1783 # Whether to fetch the user profile from the userinfo endpoint. Valid
1784 # values are: "auto" or "userinfo_endpoint".
1785 #
1786 # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
1787 # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
1788 #
1789 #user_profile_method: "userinfo_endpoint"
1790
1791 # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
1792 # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
1793 #
1794 #allow_existing_users: true
1795
1796 # An external module can be provided here as a custom solution to mapping
1797 # attributes returned from a OIDC provider onto a matrix user.
1798 #
1799 user_mapping_provider:
1800 # The custom module's class. Uncomment to use a custom module.
1801 # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
1802 #
1803 # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
1804 # for information on implementing a custom mapping provider.
1805 #
1806 #module: mapping_provider.OidcMappingProvider
1807
1808 # Custom configuration values for the module. This section will be passed as
1809 # a Python dictionary to the user mapping provider module's `parse_config`
1810 # method.
1811 #
1812 # The examples below are intended for the default provider: they should be
1813 # changed if using a custom provider.
1814 #
1815 config:
1816 # name of the claim containing a unique identifier for the user.
1817 # Defaults to `sub`, which OpenID Connect compliant providers should provide.
1818 #
1819 #subject_claim: "sub"
1820
1821 # Jinja2 template for the localpart of the MXID.
1822 #
1823 # When rendering, this template is given the following variables:
1824 # * user: The claims returned by the UserInfo Endpoint and/or in the ID
1825 # Token
1826 #
1827 # If this is not set, the user will be prompted to choose their
1828 # own username.
1829 #
1830 #localpart_template: "{{ user.preferred_username }}"
1831
1832 # Jinja2 template for the display name to set on first login.
1833 #
1834 # If unset, no displayname will be set.
1835 #
1836 #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
1837
1838 # Jinja2 templates for extra attributes to send back to the client during
1839 # login.
1840 #
1841 # Note that these are non-standard and clients will ignore them without modifications.
1842 #
1843 #extra_attributes:
1844 #birthdate: "{{ user.birthdate }}"
1845
1814 # for information on how to configure these options.
1815 #
1816 # For backwards compatibility, it is also possible to configure a single OIDC
1817 # provider via an 'oidc_config' setting. This is now deprecated and admins are
1818 # advised to migrate to the 'oidc_providers' format. (When doing that migration,
1819 # use 'oidc' for the idp_id to ensure that existing users continue to be
1820 # recognised.)
1821 #
1822 oidc_providers:
1823 # Generic example
1824 #
1825 #- idp_id: my_idp
1826 # idp_name: "My OpenID provider"
1827 # idp_icon: "mxc://example.com/mediaid"
1828 # discover: false
1829 # issuer: "https://accounts.example.com/"
1830 # client_id: "provided-by-your-issuer"
1831 # client_secret: "provided-by-your-issuer"
1832 # client_auth_method: client_secret_post
1833 # scopes: ["openid", "profile"]
1834 # authorization_endpoint: "https://accounts.example.com/oauth2/auth"
1835 # token_endpoint: "https://accounts.example.com/oauth2/token"
1836 # userinfo_endpoint: "https://accounts.example.com/userinfo"
1837 # jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
1838 # skip_verification: true
1839
1840 # For use with Keycloak
1841 #
1842 #- idp_id: keycloak
1843 # idp_name: Keycloak
1844 # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
1845 # client_id: "synapse"
1846 # client_secret: "copy secret generated in Keycloak UI"
1847 # scopes: ["openid", "profile"]
1848
1849 # For use with Github
1850 #
1851 #- idp_id: github
1852 # idp_name: Github
1853 # discover: false
1854 # issuer: "https://github.com/"
1855 # client_id: "your-client-id" # TO BE FILLED
1856 # client_secret: "your-client-secret" # TO BE FILLED
1857 # authorization_endpoint: "https://github.com/login/oauth/authorize"
1858 # token_endpoint: "https://github.com/login/oauth/access_token"
1859 # userinfo_endpoint: "https://api.github.com/user"
1860 # scopes: ["read:user"]
1861 # user_mapping_provider:
1862 # config:
1863 # subject_claim: "id"
1864 # localpart_template: "{{ user.login }}"
1865 # display_name_template: "{{ user.name }}"
18461866
18471867
18481868 # Enable Central Authentication Service (CAS) for registration and login.
18921912 # phishing attacks from evil.site. To avoid this, include a slash after the
18931913 # hostname: "https://my.client/".
18941914 #
1895 # If public_baseurl is set, then the login fallback page (used by clients
1896 # that don't natively support the required login flows) is whitelisted in
1897 # addition to any URLs in this list.
1915 # The login fallback page (used by clients that don't natively support the
1916 # required login flows) is automatically whitelisted in addition to any URLs
1917 # in this list.
18981918 #
18991919 # By default, this list is empty.
19001920 #
19071927 # directory, default templates from within the Synapse package will be used.
19081928 #
19091929 # Synapse will look for the following templates in this directory:
1930 #
1931 # * HTML page to prompt the user to choose an Identity Provider during
1932 # login: 'sso_login_idp_picker.html'.
1933 #
1934 # This is only used if multiple SSO Identity Providers are configured.
1935 #
1936 # When rendering, this template is given the following variables:
1937 # * redirect_url: the URL that the user will be redirected to after
1938 # login. Needs manual escaping (see
1939 # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
1940 #
1941 # * server_name: the homeserver's name.
1942 #
1943 # * providers: a list of available Identity Providers. Each element is
1944 # an object with the following attributes:
1945 # * idp_id: unique identifier for the IdP
1946 # * idp_name: user-facing name for the IdP
1947 #
1948 # The rendered HTML page should contain a form which submits its results
1949 # back as a GET request, with the following query parameters:
1950 #
1951 # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
1952 # to the template)
1953 #
1954 # * idp: the 'idp_id' of the chosen IDP.
19101955 #
19111956 # * HTML page for a confirmation step before redirecting back to the client
19121957 # with the login token: 'sso_redirect_confirm.html'.
19421987 # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
19431988 #
19441989 # This template has no additional variables.
1990 #
1991 # * HTML page shown after a user-interactive authentication session which
1992 # does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
1993 #
1994 # When rendering, this template is given the following variables:
1995 # * server_name: the homeserver's name.
1996 # * user_id_to_verify: the MXID of the user that we are trying to
1997 # validate.
19451998 #
19461999 # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
19472000 # attempts to login: 'sso_account_deactivated.html'.
3030 1. Adjust synapse configuration files as above.
3131 1. Copy the `*.service` and `*.target` files in [system](system) to
3232 `/etc/systemd/system`.
33 1. Run `systemctl deamon-reload` to tell systemd to load the new unit files.
33 1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
3434 1. Run `systemctl enable matrix-synapse.service`. This will configure the
3535 synapse master process to be started as part of the `matrix-synapse.target`
3636 target.
1414 workers only work with PostgreSQL-based Synapse deployments. SQLite should only
1515 be used for demo purposes and any admin considering workers should already be
1616 running PostgreSQL.
17
18 See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
19 for a higher level overview.
1720
1821 ## Main process/worker communication
1922
5558 virtualenv, these can be installed with:
5659
5760 ```sh
58 pip install matrix-synapse[redis]
61 pip install "matrix-synapse[redis]"
5962 ```
6063
6164 Note that these dependencies are included when synapse is installed with `pip
213216 ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
214217 ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
215218 ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
219 ^/_matrix/client/(api/v1|r0|unstable)/devices$
216220 ^/_matrix/client/(api/v1|r0|unstable)/keys/query$
217221 ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
218222 ^/_matrix/client/versions$
9999 synapse/util/async_helpers.py,
100100 synapse/util/caches,
101101 synapse/util/metrics.py,
102 synapse/util/stringutils.py,
102103 tests/replication,
103104 tests/test_utils,
104105 tests/handlers/test_password_providers.py,
106 tests/rest/client/v1/test_login.py,
105107 tests/rest/client/v2_alpha/test_auth.py,
106108 tests/util/test_stream_change_cache.py
107109
6969
7070 BOOLEAN_COLUMNS = {
7171 "events": ["processed", "outlier", "contains_url"],
72 "rooms": ["is_public"],
72 "rooms": ["is_public", "has_auth_chain_index"],
7373 "event_edges": ["is_state"],
7474 "presence_list": ["accepted"],
7575 "presence_stream": ["currently_active"],
628628 await self._setup_state_group_id_seq()
629629 await self._setup_user_id_seq()
630630 await self._setup_events_stream_seqs()
631 await self._setup_device_inbox_seq()
631632
632633 # Step 3. Get tables.
633634 self.progress.set_state("Fetching tables")
909910 return await self.postgres_store.db_pool.runInteraction(
910911 "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
911912 )
913
914 async def _setup_device_inbox_seq(self):
915 """Set the device inbox sequence to the correct value.
916 """
917 curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
918 table="device_inbox",
919 keyvalues={},
920 retcol="COALESCE(MAX(stream_id), 1)",
921 allow_none=True,
922 )
923
924 curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
925 table="device_federation_outbox",
926 keyvalues={},
927 retcol="COALESCE(MAX(stream_id), 1)",
928 allow_none=True,
929 )
930
931 next_id = max(curr_local_id, curr_federation_id) + 1
932
933 def r(txn):
934 txn.execute(
935 "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
936 )
937
938 return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
912939
913940
914941 ##############################################
1414
1515 # Stub for frozendict.
1616
17 from typing import (
18 Any,
19 Hashable,
20 Iterable,
21 Iterator,
22 Mapping,
23 overload,
24 Tuple,
25 TypeVar,
26 )
17 from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
2718
2819 _KT = TypeVar("_KT", bound=Hashable) # Key type.
2920 _VT = TypeVar("_VT") # Value type.
66 Callable,
77 Dict,
88 Hashable,
9 ItemsView,
10 Iterable,
911 Iterator,
10 Iterable,
11 ItemsView,
1212 KeysView,
1313 List,
1414 Mapping,
1515 Optional,
1616 Sequence,
17 Tuple,
1718 Type,
1819 TypeVar,
19 Tuple,
2020 Union,
2121 ValuesView,
2222 overload,
1515 """Contains *incomplete* type hints for txredisapi.
1616 """
1717
18 from typing import List, Optional, Union, Type
18 from typing import List, Optional, Type, Union
1919
2020 class RedisProtocol:
2121 def publish(self, channel: str, message: bytes): ...
4747 except ImportError:
4848 pass
4949
50 __version__ = "1.25.0"
50 __version__ = "1.26.0"
5151
5252 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
5353 # We import here so that we don't have to install a bunch of deps when
3232 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
3333 from synapse.appservice import ApplicationService
3434 from synapse.events import EventBase
35 from synapse.http import get_request_user_agent
3536 from synapse.http.site import SynapseRequest
3637 from synapse.logging import opentracing as opentracing
3738 from synapse.storage.databases.main.registration import TokenLookupResult
185186 AuthError if access is denied for the user in the access token
186187 """
187188 try:
188 ip_addr = self.hs.get_ip_from_request(request)
189 user_agent = request.get_user_agent("")
189 ip_addr = request.getClientIP()
190 user_agent = get_request_user_agent(request)
190191
191192 access_token = self.get_access_token_from_request(request)
192193
274275 return None, None
275276
276277 if app_service.ip_range_whitelist:
277 ip_address = IPAddress(self.hs.get_ip_from_request(request))
278 ip_address = IPAddress(request.getClientIP())
278279 if ip_address not in app_service.ip_range_whitelist:
279280 return None, None
280281
5050 class RoomVersion:
5151 """An object which describes the unique attributes of a room version."""
5252
53 identifier = attr.ib() # str; the identifier for this version
54 disposition = attr.ib() # str; one of the RoomDispositions
55 event_format = attr.ib() # int; one of the EventFormatVersions
56 state_res = attr.ib() # int; one of the StateResolutionVersions
57 enforce_key_validity = attr.ib() # bool
53 identifier = attr.ib(type=str) # the identifier for this version
54 disposition = attr.ib(type=str) # one of the RoomDispositions
55 event_format = attr.ib(type=int) # one of the EventFormatVersions
56 state_res = attr.ib(type=int) # one of the StateResolutionVersions
57 enforce_key_validity = attr.ib(type=bool)
5858
5959 # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
6060 special_case_aliases_auth = attr.ib(type=bool)
6363 # * Floats
6464 # * NaN, Infinity, -Infinity
6565 strict_canonicaljson = attr.ib(type=bool)
66 # bool: MSC2209: Check 'notifications' key while verifying
66 # MSC2209: Check 'notifications' key while verifying
6767 # m.room.power_levels auth rules.
6868 limit_notifications_power_levels = attr.ib(type=bool)
69 # MSC2174/MSC2176: Apply updated redaction rules algorithm.
70 msc2176_redaction_rules = attr.ib(type=bool)
6971
7072
7173 class RoomVersions:
7880 special_case_aliases_auth=True,
7981 strict_canonicaljson=False,
8082 limit_notifications_power_levels=False,
83 msc2176_redaction_rules=False,
8184 )
8285 V2 = RoomVersion(
8386 "2",
8891 special_case_aliases_auth=True,
8992 strict_canonicaljson=False,
9093 limit_notifications_power_levels=False,
94 msc2176_redaction_rules=False,
9195 )
9296 V3 = RoomVersion(
9397 "3",
98102 special_case_aliases_auth=True,
99103 strict_canonicaljson=False,
100104 limit_notifications_power_levels=False,
105 msc2176_redaction_rules=False,
101106 )
102107 V4 = RoomVersion(
103108 "4",
108113 special_case_aliases_auth=True,
109114 strict_canonicaljson=False,
110115 limit_notifications_power_levels=False,
116 msc2176_redaction_rules=False,
111117 )
112118 V5 = RoomVersion(
113119 "5",
118124 special_case_aliases_auth=True,
119125 strict_canonicaljson=False,
120126 limit_notifications_power_levels=False,
127 msc2176_redaction_rules=False,
121128 )
122129 V6 = RoomVersion(
123130 "6",
128135 special_case_aliases_auth=False,
129136 strict_canonicaljson=True,
130137 limit_notifications_power_levels=True,
138 msc2176_redaction_rules=False,
139 )
140 MSC2176 = RoomVersion(
141 "org.matrix.msc2176",
142 RoomDisposition.UNSTABLE,
143 EventFormatVersions.V3,
144 StateResolutionVersions.V2,
145 enforce_key_validity=True,
146 special_case_aliases_auth=False,
147 strict_canonicaljson=True,
148 limit_notifications_power_levels=True,
149 msc2176_redaction_rules=True,
131150 )
132151
133152
140159 RoomVersions.V4,
141160 RoomVersions.V5,
142161 RoomVersions.V6,
162 RoomVersions.MSC2176,
143163 )
144164 } # type: Dict[str, RoomVersion]
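Each capability flag on `RoomVersion` is consulted wherever behaviour differs between room versions, and the new `msc2176_redaction_rules` flag works the same way. A minimal sketch, assuming the `KNOWN_ROOM_VERSIONS` map shown above, of looking up a version by its identifier and branching on the flag:

```python
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

def uses_msc2176_redactions(room_version_id: str) -> bool:
    """Return True if the given room version applies the MSC2176 redaction rules."""
    room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
    if room_version is None:
        raise ValueError("Unknown room version: %s" % (room_version_id,))
    return room_version.msc2176_redaction_rules

# Only the experimental room version enables the new rules:
assert uses_msc2176_redactions("org.matrix.msc2176")
assert not uses_msc2176_redactions("6")
```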
4141 """
4242 if hs_config.form_secret is None:
4343 raise ConfigError("form_secret not set in config")
44 if hs_config.public_baseurl is None:
45 raise ConfigError("public_baseurl not set in config")
4644
4745 self._hmac_secret = hs_config.form_secret.encode("utf-8")
4846 self._public_baseurl = hs_config.public_baseurl
00 # -*- coding: utf-8 -*-
11 # Copyright 2017 New Vector Ltd
2 # Copyright 2019-2021 The Matrix.org Foundation C.I.C
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1819 import socket
1920 import sys
2021 import traceback
21 from typing import Iterable
22 from typing import Awaitable, Callable, Iterable
2223
2324 from typing_extensions import NoReturn
2425
140141 sys.stderr.write(" %s\n" % (line.rstrip(),))
141142 sys.stderr.write("*" * line_length + "\n")
142143 sys.exit(1)
144
145
146 def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
147 """Register a callback with the reactor, to be called once it is running
148
149 This can be used to initialise parts of the system which require an asynchronous
150 setup.
151
152 Any exception raised by the callback will be printed and logged, and the process
153 will exit.
154 """
155
156 async def wrapper():
157 try:
158 await cb(*args, **kwargs)
159 except Exception:
160 # previously, we used Failure().printTraceback() here, in the hope that it
161 # would give better tracebacks than traceback.print_exc(). However, that
162 # doesn't handle chained exceptions (with a __cause__ or __context__) well,
163 # and I *think* the need for Failure() is reduced now that we mostly use
164 # async/await.
165
166 # Write the exception to both the logs *and* the unredirected stderr,
167 # because people tend to get confused if it only goes to one or the other.
168 #
169 # One problem with this is that if people are using a logging config that
170 # logs to the console (as is common eg under docker), they will get two
171 # copies of the exception. We could maybe try to detect that, but it's
172 # probably a cost we can bear.
173 logger.fatal("Error during startup", exc_info=True)
174 print("Error during startup:", file=sys.__stderr__)
175 traceback.print_exc(file=sys.__stderr__)
176
177 # it's no use calling sys.exit here, since that just raises a SystemExit
178 # exception which is then caught by the reactor, and everything carries
179 # on as normal.
180 os._exit(1)
181
182 reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
143183
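The docstring above is the whole contract of `register_start`: pass an async callable and its arguments, and it runs once the reactor is up, logging the traceback and exiting the process on failure. A minimal usage sketch; the wrapping `schedule_startup` helper is hypothetical, while `get_oidc_handler()` and `load_metadata()` are the calls used by the homeserver entry point in this codebase:

```python
from synapse.app._base import register_start
from synapse.server import HomeServer

def schedule_startup(hs: HomeServer) -> None:
    """Queue asynchronous initialisation to run once the reactor is running."""

    async def start(hs: HomeServer) -> None:
        # example: OIDC provider metadata can only be loaded asynchronously
        await hs.get_oidc_handler().load_metadata()

    register_start(start, hs)
```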
144184
145185 def listen_metrics(bind_addresses, port):
226266 logger.info("Context factories updated.")
227267
228268
229 def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
269 async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
230270 """
231271 Start a Synapse server or worker.
232272
240280 hs: homeserver instance
241281 listeners: Listener configuration ('listeners' in homeserver.yaml)
242282 """
243 try:
244 # Set up the SIGHUP machinery.
245 if hasattr(signal, "SIGHUP"):
246
247 reactor = hs.get_reactor()
248
249 @wrap_as_background_process("sighup")
250 def handle_sighup(*args, **kwargs):
251 # Tell systemd our state, if we're using it. This will silently fail if
252 # we're not using systemd.
253 sdnotify(b"RELOADING=1")
254
255 for i, args, kwargs in _sighup_callbacks:
256 i(*args, **kwargs)
257
258 sdnotify(b"READY=1")
259
260 # We defer running the sighup handlers until next reactor tick. This
261 # is so that we're in a sane state, e.g. flushing the logs may fail
262 # if the sighup happens in the middle of writing a log entry.
263 def run_sighup(*args, **kwargs):
264 # `callFromThread` should be "signal safe" as well as thread
265 # safe.
266 reactor.callFromThread(handle_sighup, *args, **kwargs)
267
268 signal.signal(signal.SIGHUP, run_sighup)
269
270 register_sighup(refresh_certificate, hs)
271
272 # Load the certificate from disk.
273 refresh_certificate(hs)
274
275 # Start the tracer
276 synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
277 hs
278 )
279
280 # It is now safe to start your Synapse.
281 hs.start_listening(listeners)
282 hs.get_datastore().db_pool.start_profiling()
283 hs.get_pusherpool().start()
284
285 # Log when we start the shut down process.
286 hs.get_reactor().addSystemEventTrigger(
287 "before", "shutdown", logger.info, "Shutting down..."
288 )
289
290 setup_sentry(hs)
291 setup_sdnotify(hs)
292
293 # If background tasks are running on the main process, start collecting the
294 # phone home stats.
295 if hs.config.run_background_tasks:
296 start_phone_stats_home(hs)
297
298 # We now freeze all allocated objects in the hopes that (almost)
299 # everything currently allocated is something that will be used for the
300 # rest of time. Doing so means less work each GC (hopefully).
301 #
302 # This only works on Python 3.7
303 if sys.version_info >= (3, 7):
304 gc.collect()
305 gc.freeze()
306 except Exception:
307 traceback.print_exc(file=sys.stderr)
283 # Set up the SIGHUP machinery.
284 if hasattr(signal, "SIGHUP"):
308285 reactor = hs.get_reactor()
309 if reactor.running:
310 reactor.stop()
311 sys.exit(1)
286
287 @wrap_as_background_process("sighup")
288 def handle_sighup(*args, **kwargs):
289 # Tell systemd our state, if we're using it. This will silently fail if
290 # we're not using systemd.
291 sdnotify(b"RELOADING=1")
292
293 for i, args, kwargs in _sighup_callbacks:
294 i(*args, **kwargs)
295
296 sdnotify(b"READY=1")
297
298 # We defer running the sighup handlers until next reactor tick. This
299 # is so that we're in a sane state, e.g. flushing the logs may fail
300 # if the sighup happens in the middle of writing a log entry.
301 def run_sighup(*args, **kwargs):
302 # `callFromThread` should be "signal safe" as well as thread
303 # safe.
304 reactor.callFromThread(handle_sighup, *args, **kwargs)
305
306 signal.signal(signal.SIGHUP, run_sighup)
307
308 register_sighup(refresh_certificate, hs)
309
310 # Load the certificate from disk.
311 refresh_certificate(hs)
312
313 # Start the tracer
314 synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
315 hs
316 )
317
318 # It is now safe to start your Synapse.
319 hs.start_listening(listeners)
320 hs.get_datastore().db_pool.start_profiling()
321 hs.get_pusherpool().start()
322
323 # Log when we start the shut down process.
324 hs.get_reactor().addSystemEventTrigger(
325 "before", "shutdown", logger.info, "Shutting down..."
326 )
327
328 setup_sentry(hs)
329 setup_sdnotify(hs)
330
331 # If background tasks are running on the main process, start collecting the
332 # phone home stats.
333 if hs.config.run_background_tasks:
334 start_phone_stats_home(hs)
335
336 # We now freeze all allocated objects in the hopes that (almost)
337 # everything currently allocated is something that will be used for the
338 # rest of time. Doing so means less work each GC (hopefully).
339 #
340 # This only works on Python 3.7
341 if sys.version_info >= (3, 7):
342 gc.collect()
343 gc.freeze()
312344
313345
314346 def setup_sentry(hs):
2020
2121 from typing_extensions import ContextManager
2222
23 from twisted.internet import address, reactor
23 from twisted.internet import address
2424
2525 import synapse
2626 import synapse.events
3333 SERVER_KEY_V2_PREFIX,
3434 )
3535 from synapse.app import _base
36 from synapse.app._base import register_start
3637 from synapse.config._base import ConfigError
3738 from synapse.config.homeserver import HomeServerConfig
3839 from synapse.config.logger import setup_logging
9899 )
99100 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
100101 from synapse.rest.client.v1.voip import VoipRestServlet
101 from synapse.rest.client.v2_alpha import groups, sync, user_directory
102 from synapse.rest.client.v2_alpha import (
103 account_data,
104 groups,
105 read_marker,
106 receipts,
107 room_keys,
108 sync,
109 tags,
110 user_directory,
111 )
102112 from synapse.rest.client.v2_alpha._base import client_patterns
103113 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
104114 from synapse.rest.client.v2_alpha.account_data import (
105115 AccountDataServlet,
106116 RoomAccountDataServlet,
107117 )
108 from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
118 from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
119 from synapse.rest.client.v2_alpha.keys import (
120 KeyChangesServlet,
121 KeyQueryServlet,
122 OneTimeKeyServlet,
123 )
109124 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
125 from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
110126 from synapse.rest.client.versions import VersionsRestServlet
111127 from synapse.rest.health import HealthResource
112128 from synapse.rest.key.v2 import KeyApiV2Resource
113129 from synapse.server import HomeServer, cache_in_self
114130 from synapse.storage.databases.main.censor_events import CensorEventsStore
115131 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
132 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
116133 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
117134 from synapse.storage.databases.main.metrics import ServerMetricsStore
118135 from synapse.storage.databases.main.monthly_active_users import (
444461 UserDirectoryStore,
445462 StatsStore,
446463 UIAuthWorkerStore,
464 EndToEndRoomKeyStore,
447465 SlavedDeviceInboxStore,
448466 SlavedDeviceStore,
449467 SlavedReceiptsStore,
500518 RegisterRestServlet(self).register(resource)
501519 LoginRestServlet(self).register(resource)
502520 ThreepidRestServlet(self).register(resource)
521 DevicesRestServlet(self).register(resource)
503522 KeyQueryServlet(self).register(resource)
523 OneTimeKeyServlet(self).register(resource)
504524 KeyChangesServlet(self).register(resource)
505525 VoipRestServlet(self).register(resource)
506526 PushRuleRestServlet(self).register(resource)
518538 room.register_servlets(self, resource, True)
519539 room.register_deprecated_servlets(self, resource)
520540 InitialSyncRestServlet(self).register(resource)
541 room_keys.register_servlets(self, resource)
542 tags.register_servlets(self, resource)
543 account_data.register_servlets(self, resource)
544 receipts.register_servlets(self, resource)
545 read_marker.register_servlets(self, resource)
546
547 SendToDeviceRestServlet(self).register(resource)
521548
522549 user_directory.register_servlets(self, resource)
523550
956983 # streams. Will no-op if no streams can be written to by this worker.
957984 hs.get_replication_streamer()
958985
959 reactor.addSystemEventTrigger(
960 "before", "startup", _base.start, hs, config.worker_listeners
961 )
986 register_start(_base.start, hs, config.worker_listeners)
962987
963988 _base.start_worker_reactor("synapse-generic-worker", config)
964989
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616
17 import gc
1817 import logging
1918 import os
2019 import sys
2120 from typing import Iterable, Iterator
2221
23 from twisted.application import service
24 from twisted.internet import defer, reactor
25 from twisted.python.failure import Failure
22 from twisted.internet import reactor
2623 from twisted.web.resource import EncodingResourceWrapper, IResource
2724 from twisted.web.server import GzipEncoderFactory
2825 from twisted.web.static import File
3936 WEB_CLIENT_PREFIX,
4037 )
4138 from synapse.app import _base
42 from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
39 from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
4340 from synapse.config._base import ConfigError
4441 from synapse.config.emailconfig import ThreepidBehaviour
4542 from synapse.config.homeserver import HomeServerConfig
6259 from synapse.rest.admin import AdminRestResource
6360 from synapse.rest.health import HealthResource
6461 from synapse.rest.key.v2 import KeyApiV2Resource
62 from synapse.rest.synapse.client.pick_idp import PickIdpResource
6563 from synapse.rest.synapse.client.pick_username import pick_username_resource
6664 from synapse.rest.well_known import WellKnownResource
6765 from synapse.server import HomeServer
7169 from synapse.util.httpresourcetree import create_resource_tree
7270 from synapse.util.manhole import manhole
7371 from synapse.util.module_loader import load_module
74 from synapse.util.rlimit import change_resource_limit
7572 from synapse.util.versionstring import get_version_string
7673
7774 logger = logging.getLogger("synapse.app.homeserver")
193190 "/.well-known/matrix/client": WellKnownResource(self),
194191 "/_synapse/admin": AdminRestResource(self),
195192 "/_synapse/client/pick_username": pick_username_resource(self),
193 "/_synapse/client/pick_idp": PickIdpResource(self),
196194 }
197195 )
198196
414412 _base.refresh_certificate(hs)
415413
416414 async def start():
417 try:
418 # Run the ACME provisioning code, if it's enabled.
419 if hs.config.acme_enabled:
420 acme = hs.get_acme_handler()
421 # Start up the webservices which we will respond to ACME
422 # challenges with, and then provision.
423 await acme.start_listening()
424 await do_acme()
425
426 # Check if it needs to be reprovisioned every day.
427 hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
428
429 # Load the OIDC provider metadatas, if OIDC is enabled.
430 if hs.config.oidc_enabled:
431 oidc = hs.get_oidc_handler()
432 # Loading the provider metadata also ensures the provider config is valid.
433 await oidc.load_metadata()
434 await oidc.load_jwks()
435
436 _base.start(hs, config.listeners)
437
438 hs.get_datastore().db_pool.updates.start_doing_background_updates()
439 except Exception:
440 # Print the exception and bail out.
441 print("Error during startup:", file=sys.stderr)
442
443 # this gives better tracebacks than traceback.print_exc()
444 Failure().printTraceback(file=sys.stderr)
445
446 if reactor.running:
447 reactor.stop()
448 sys.exit(1)
449
450 reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
415 # Run the ACME provisioning code, if it's enabled.
416 if hs.config.acme_enabled:
417 acme = hs.get_acme_handler()
418 # Start up the webservices which we will respond to ACME
419 # challenges with, and then provision.
420 await acme.start_listening()
421 await do_acme()
422
423 # Check if it needs to be reprovisioned every day.
424 hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
425
426 # Load the OIDC provider metadatas, if OIDC is enabled.
427 if hs.config.oidc_enabled:
428 oidc = hs.get_oidc_handler()
429 # Loading the provider metadata also ensures the provider config is valid.
430 await oidc.load_metadata()
431
432 await _base.start(hs, config.listeners)
433
434 hs.get_datastore().db_pool.updates.start_doing_background_updates()
435
436 register_start(start)
451437
452438 return hs
453439
482468 indent += 1
483469 yield ":\n%s%s" % (" " * indent, str(e))
484470 e = e.__cause__
485
486
487 class SynapseService(service.Service):
488 """
489 A twisted Service class that will start synapse. Used to run synapse
490 via twistd and a .tac.
491 """
492
493 def __init__(self, config):
494 self.config = config
495
496 def startService(self):
497 hs = setup(self.config)
498 change_resource_limit(hs.config.soft_file_limit)
499 if hs.config.gc_thresholds:
500 gc.set_threshold(*hs.config.gc_thresholds)
501
502 def stopService(self):
503 return self._port.stopListening()
504471
505472
506473 def run(hs):
251251 env = jinja2.Environment(loader=loader, autoescape=autoescape)
252252
253253 # Update the environment with our custom filters
254 env.filters.update({"format_ts": _format_ts_filter})
255 if self.public_baseurl:
256 env.filters.update(
257 {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
258 )
254 env.filters.update(
255 {
256 "format_ts": _format_ts_filter,
257 "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
258 }
259 )
259260
260261 for filename in filenames:
261262 # Load the template
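Because `public_baseurl` is now always populated, the `mxc_to_http` filter is registered unconditionally alongside `format_ts`. A small sketch of how custom filters like these are attached to a Jinja2 environment and used from a template; the filter bodies here are simplified stand-ins, not Synapse's implementations:

```python
import datetime

import jinja2

def _format_ts_filter(value: int, format: str) -> str:
    # simplified stand-in: format a millisecond timestamp
    return datetime.datetime.fromtimestamp(value / 1000).strftime(format)

def _create_mxc_to_http_filter(public_baseurl: str):
    # simplified stand-in: turn an mxc:// URI into a thumbnail URL under public_baseurl
    def mxc_to_http(value: str, width: int, height: int) -> str:
        server_and_media_id = value[len("mxc://"):]
        return "%s_matrix/media/v1/thumbnail/%s?width=%d&height=%d" % (
            public_baseurl, server_and_media_id, width, height,
        )
    return mxc_to_http

env = jinja2.Environment()
env.filters.update(
    {
        "format_ts": _format_ts_filter,
        "mxc_to_http": _create_mxc_to_http_filter("https://example.com/"),
    }
)
template = env.from_string("{{ url | mxc_to_http(64, 64) }}")
print(template.render(url="mxc://example.com/abc"))
```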
5555 """
5656 # copy `config_path` before modifying it.
5757 path = list(config_path)
58 for p in list(e.path):
58 for p in list(e.absolute_path):
5959 if isinstance(p, int):
6060 path.append("<item %i>" % p)
6161 else:
3939 self.cas_required_attributes = {}
4040
4141 def generate_config_section(self, config_dir_path, server_name, **kwargs):
42 return """
42 return """\
4343 # Enable Central Authentication Service (CAS) for registration and login.
4444 #
4545 cas_config:
164164 missing = []
165165 if not self.email_notif_from:
166166 missing.append("email.notif_from")
167
168 # public_baseurl is required to build password reset and validation links that
169 # will be emailed to users
170 if config.get("public_baseurl") is None:
171 missing.append("public_baseurl")
172167
173168 if missing:
174169 raise ConfigError(
268263 if not self.email_notif_from:
269264 missing.append("email.notif_from")
270265
271 if config.get("public_baseurl") is None:
272 missing.append("public_baseurl")
273
274266 if missing:
275267 raise ConfigError(
276268 "email.enable_notifs is True but required keys are missing: %s"
00 # -*- coding: utf-8 -*-
11 # Copyright 2020 Quentin Gliech
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
16 import string
17 from typing import Iterable, Optional, Tuple, Type
18
19 import attr
20
21 from synapse.config._util import validate_config
1522 from synapse.python_dependencies import DependencyException, check_requirements
23 from synapse.types import Collection, JsonDict
1624 from synapse.util.module_loader import load_module
25 from synapse.util.stringutils import parse_and_validate_mxc_uri
1726
1827 from ._base import Config, ConfigError
1928
2433 section = "oidc"
2534
2635 def read_config(self, config, **kwargs):
27 self.oidc_enabled = False
28
29 oidc_config = config.get("oidc_config")
30
31 if not oidc_config or not oidc_config.get("enabled", False):
36 self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
37 if not self.oidc_providers:
3238 return
3339
3440 try:
3541 check_requirements("oidc")
3642 except DependencyException as e:
37 raise ConfigError(e.message)
43 raise ConfigError(e.message) from e
3844
3945 public_baseurl = self.public_baseurl
40 if public_baseurl is None:
41 raise ConfigError("oidc_config requires a public_baseurl to be set")
4246 self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
4347
44 self.oidc_enabled = True
45 self.oidc_discover = oidc_config.get("discover", True)
46 self.oidc_issuer = oidc_config["issuer"]
47 self.oidc_client_id = oidc_config["client_id"]
48 self.oidc_client_secret = oidc_config["client_secret"]
49 self.oidc_client_auth_method = oidc_config.get(
50 "client_auth_method", "client_secret_basic"
51 )
52 self.oidc_scopes = oidc_config.get("scopes", ["openid"])
53 self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
54 self.oidc_token_endpoint = oidc_config.get("token_endpoint")
55 self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
56 self.oidc_jwks_uri = oidc_config.get("jwks_uri")
57 self.oidc_skip_verification = oidc_config.get("skip_verification", False)
58 self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
59 self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
60
61 ump_config = oidc_config.get("user_mapping_provider", {})
62 ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
63 ump_config.setdefault("config", {})
64
65 (
66 self.oidc_user_mapping_provider_class,
67 self.oidc_user_mapping_provider_config,
68 ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
69
70 # Ensure loaded user mapping module has defined all necessary methods
71 required_methods = [
72 "get_remote_user_id",
73 "map_user_attributes",
74 ]
75 missing_methods = [
76 method
77 for method in required_methods
78 if not hasattr(self.oidc_user_mapping_provider_class, method)
79 ]
80 if missing_methods:
81 raise ConfigError(
82 "Class specified by oidc_config."
83 "user_mapping_provider.module is missing required "
84 "methods: %s" % (", ".join(missing_methods),)
85 )
48 @property
49 def oidc_enabled(self) -> bool:
50 # OIDC is enabled if we have a provider
51 return bool(self.oidc_providers)
8652
8753 def generate_config_section(self, config_dir_path, server_name, **kwargs):
8854 return """\
89 # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
55 # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
56 # and login.
57 #
58 # Options for each entry include:
59 #
60 # idp_id: a unique identifier for this identity provider. Used internally
61 # by Synapse; should be a single word such as 'github'.
62 #
63 # Note that, if this is changed, users authenticating via that provider
64 # will no longer be recognised as the same user!
65 #
66 # idp_name: A user-facing name for this identity provider, which is used to
67 # offer the user a choice of login mechanisms.
68 #
69 # idp_icon: An optional icon for this identity provider, which is presented
70 # by identity picker pages. If given, must be an MXC URI of the format
71 # mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI
72 # is to upload an image to an (unencrypted) room and then copy the "url"
73 # from the source of the event.)
74 #
75 # discover: set to 'false' to disable the use of the OIDC discovery mechanism
76 # to discover endpoints. Defaults to true.
77 #
78 # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
79 # is enabled) to discover the provider's endpoints.
80 #
81 # client_id: Required. oauth2 client id to use.
82 #
83 # client_secret: Required. oauth2 client secret to use.
84 #
85 # client_auth_method: auth method to use when exchanging the token. Valid
86 # values are 'client_secret_basic' (default), 'client_secret_post' and
87 # 'none'.
88 #
89 # scopes: list of scopes to request. This should normally include the "openid"
90 # scope. Defaults to ["openid"].
91 #
92 # authorization_endpoint: the oauth2 authorization endpoint. Required if
93 # provider discovery is disabled.
94 #
95 # token_endpoint: the oauth2 token endpoint. Required if provider discovery is
96 # disabled.
97 #
98 # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
99 # disabled and the 'openid' scope is not requested.
100 #
101 # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
102 # the 'openid' scope is used.
103 #
104 # skip_verification: set to 'true' to skip metadata verification. Use this if
105 # you are connecting to a provider that is not OpenID Connect compliant.
106 # Defaults to false. Avoid this in production.
107 #
108 # user_profile_method: Whether to fetch the user profile from the userinfo
109 # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
110 #
111 # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
112 # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
113 # userinfo endpoint.
114 #
115 # allow_existing_users: set to 'true' to allow a user logging in via OIDC to
116 # match a pre-existing account instead of failing. This could be used if
117 # switching from password logins to OIDC. Defaults to false.
118 #
119 # user_mapping_provider: Configuration for how attributes returned from an OIDC
120 # provider are mapped onto a matrix user. This setting has the following
121 # sub-properties:
122 #
123 # module: The class name of a custom mapping module. Default is
124 # {mapping_provider!r}.
125 # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
126 # for information on implementing a custom mapping provider.
127 #
128 # config: Configuration for the mapping provider module. This section will
129 # be passed as a Python dictionary to the user mapping provider
130 # module's `parse_config` method.
131 #
132 # For the default provider, the following settings are available:
133 #
134 # sub: name of the claim containing a unique identifier for the
135 # user. Defaults to 'sub', which OpenID Connect compliant
136 # providers should provide.
137 #
138 # localpart_template: Jinja2 template for the localpart of the MXID.
139 # If this is not set, the user will be prompted to choose their
140 # own username.
141 #
142 # display_name_template: Jinja2 template for the display name to set
143 # on first login. If unset, no displayname will be set.
144 #
145 # extra_attributes: a map of Jinja2 templates for extra attributes
146 # to send back to the client during login.
147 # Note that these are non-standard and clients will ignore them
148 # without modifications.
149 #
150 # When rendering, the Jinja2 templates are given a 'user' variable,
151 # which is set to the claims returned by the UserInfo Endpoint and/or
152 # in the ID Token.
90153 #
91154 # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
92 # for some example configurations.
93 #
94 oidc_config:
95 # Uncomment the following to enable authorization against an OpenID Connect
96 # server. Defaults to false.
155 # for information on how to configure these options.
156 #
157 # For backwards compatibility, it is also possible to configure a single OIDC
158 # provider via an 'oidc_config' setting. This is now deprecated and admins are
159 # advised to migrate to the 'oidc_providers' format. (When doing that migration,
160 # use 'oidc' for the idp_id to ensure that existing users continue to be
161 # recognised.)
162 #
163 oidc_providers:
164 # Generic example
97165 #
98 #enabled: true
99
100 # Uncomment the following to disable use of the OIDC discovery mechanism to
101 # discover endpoints. Defaults to true.
166 #- idp_id: my_idp
167 # idp_name: "My OpenID provider"
168 # idp_icon: "mxc://example.com/mediaid"
169 # discover: false
170 # issuer: "https://accounts.example.com/"
171 # client_id: "provided-by-your-issuer"
172 # client_secret: "provided-by-your-issuer"
173 # client_auth_method: client_secret_post
174 # scopes: ["openid", "profile"]
175 # authorization_endpoint: "https://accounts.example.com/oauth2/auth"
176 # token_endpoint: "https://accounts.example.com/oauth2/token"
177 # userinfo_endpoint: "https://accounts.example.com/userinfo"
178 # jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
179 # skip_verification: true
180
181 # For use with Keycloak
102182 #
103 #discover: false
104
105 # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
106 # discover the provider's endpoints.
183 #- idp_id: keycloak
184 # idp_name: Keycloak
185 # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
186 # client_id: "synapse"
187 # client_secret: "copy secret generated in Keycloak UI"
188 # scopes: ["openid", "profile"]
189
190 # For use with Github
107191 #
108 # Required if 'enabled' is true.
109 #
110 #issuer: "https://accounts.example.com/"
111
112 # oauth2 client id to use.
113 #
114 # Required if 'enabled' is true.
115 #
116 #client_id: "provided-by-your-issuer"
117
118 # oauth2 client secret to use.
119 #
120 # Required if 'enabled' is true.
121 #
122 #client_secret: "provided-by-your-issuer"
123
124 # auth method to use when exchanging the token.
125 # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
126 # 'none'.
127 #
128 #client_auth_method: client_secret_post
129
130 # list of scopes to request. This should normally include the "openid" scope.
131 # Defaults to ["openid"].
132 #
133 #scopes: ["openid", "profile"]
134
135 # the oauth2 authorization endpoint. Required if provider discovery is disabled.
136 #
137 #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
138
139 # the oauth2 token endpoint. Required if provider discovery is disabled.
140 #
141 #token_endpoint: "https://accounts.example.com/oauth2/token"
142
143 # the OIDC userinfo endpoint. Required if discovery is disabled and the
144 # "openid" scope is not requested.
145 #
146 #userinfo_endpoint: "https://accounts.example.com/userinfo"
147
148 # URI where to fetch the JWKS. Required if discovery is disabled and the
149 # "openid" scope is used.
150 #
151 #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
152
153 # Uncomment to skip metadata verification. Defaults to false.
154 #
155 # Use this if you are connecting to a provider that is not OpenID Connect
156 # compliant.
157 # Avoid this in production.
158 #
159 #skip_verification: true
160
161 # Whether to fetch the user profile from the userinfo endpoint. Valid
162 # values are: "auto" or "userinfo_endpoint".
163 #
164 # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
165 # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
166 #
167 #user_profile_method: "userinfo_endpoint"
168
169 # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
170 # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
171 #
172 #allow_existing_users: true
173
174 # An external module can be provided here as a custom solution to mapping
175 # attributes returned from a OIDC provider onto a matrix user.
176 #
177 user_mapping_provider:
178 # The custom module's class. Uncomment to use a custom module.
179 # Default is {mapping_provider!r}.
180 #
181 # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
182 # for information on implementing a custom mapping provider.
183 #
184 #module: mapping_provider.OidcMappingProvider
185
186 # Custom configuration values for the module. This section will be passed as
187 # a Python dictionary to the user mapping provider module's `parse_config`
188 # method.
189 #
190 # The examples below are intended for the default provider: they should be
191 # changed if using a custom provider.
192 #
193 config:
194 # name of the claim containing a unique identifier for the user.
195 # Defaults to `sub`, which OpenID Connect compliant providers should provide.
196 #
197 #subject_claim: "sub"
198
199 # Jinja2 template for the localpart of the MXID.
200 #
201 # When rendering, this template is given the following variables:
202 # * user: The claims returned by the UserInfo Endpoint and/or in the ID
203 # Token
204 #
205 # If this is not set, the user will be prompted to choose their
206 # own username.
207 #
208 #localpart_template: "{{{{ user.preferred_username }}}}"
209
210 # Jinja2 template for the display name to set on first login.
211 #
212 # If unset, no displayname will be set.
213 #
214 #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
215
216 # Jinja2 templates for extra attributes to send back to the client during
217 # login.
218 #
219 # Note that these are non-standard and clients will ignore them without modifications.
220 #
221 #extra_attributes:
222 #birthdate: "{{{{ user.birthdate }}}}"
192 #- idp_id: github
193 # idp_name: Github
194 # discover: false
195 # issuer: "https://github.com/"
196 # client_id: "your-client-id" # TO BE FILLED
197 # client_secret: "your-client-secret" # TO BE FILLED
198 # authorization_endpoint: "https://github.com/login/oauth/authorize"
199 # token_endpoint: "https://github.com/login/oauth/access_token"
200 # userinfo_endpoint: "https://api.github.com/user"
201 # scopes: ["read:user"]
202 # user_mapping_provider:
203 # config:
204 # subject_claim: "id"
205 # localpart_template: "{{ user.login }}"
206 # display_name_template: "{{ user.name }}"
223207 """.format(
224208 mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
225209 )
210
211
212 # jsonschema definition of the configuration settings for an oidc identity provider
213 OIDC_PROVIDER_CONFIG_SCHEMA = {
214 "type": "object",
215 "required": ["issuer", "client_id", "client_secret"],
216 "properties": {
217 # TODO: fix the maxLength here depending on what MSC2528 decides
218 # remember that we prefix the ID given here with `oidc-`
219 "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
220 "idp_name": {"type": "string"},
221 "idp_icon": {"type": "string"},
222 "discover": {"type": "boolean"},
223 "issuer": {"type": "string"},
224 "client_id": {"type": "string"},
225 "client_secret": {"type": "string"},
226 "client_auth_method": {
227 "type": "string",
228 # the following list is the same as the keys of
229 # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
230 # to avoid importing authlib here.
231 "enum": ["client_secret_basic", "client_secret_post", "none"],
232 },
233 "scopes": {"type": "array", "items": {"type": "string"}},
234 "authorization_endpoint": {"type": "string"},
235 "token_endpoint": {"type": "string"},
236 "userinfo_endpoint": {"type": "string"},
237 "jwks_uri": {"type": "string"},
238 "skip_verification": {"type": "boolean"},
239 "user_profile_method": {
240 "type": "string",
241 "enum": ["auto", "userinfo_endpoint"],
242 },
243 "allow_existing_users": {"type": "boolean"},
244 "user_mapping_provider": {"type": ["object", "null"]},
245 },
246 }
247
248 # the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
249 OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
250 "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
251 }
252
253
254 # the `oidc_providers` list can either be None (as it is in the default config), or
255 # a list of provider configs, each of which requires an explicit ID and name.
256 OIDC_PROVIDER_LIST_SCHEMA = {
257 "oneOf": [
258 {"type": "null"},
259 {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
260 ]
261 }
262
263 # the `oidc_config` setting can either be None (which it used to be in the default
264 # config), or an object. If an object, it is ignored unless it has an "enabled: True"
265 # property.
266 #
267 # It's *possible* to represent this with jsonschema, but the resultant errors aren't
268 # particularly clear, so we just check for either an object or a null here, and do
269 # additional checks in the code.
270 OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
271
272 # the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
273 MAIN_CONFIG_SCHEMA = {
274 "type": "object",
275 "properties": {
276 "oidc_config": OIDC_CONFIG_SCHEMA,
277 "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
278 },
279 }
280
281
282 def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
283 """extract and parse the OIDC provider configs from the config dict
284
285 The configuration may contain either a single `oidc_config` object with an
286 `enabled: True` property, or a list of provider configurations under
287 `oidc_providers`, *or both*.
288
289 Returns a generator which yields the OidcProviderConfig objects
290 """
291 validate_config(MAIN_CONFIG_SCHEMA, config, ())
292
293 for i, p in enumerate(config.get("oidc_providers") or []):
294 yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,)))
295
296 # for backwards-compatibility, it is also possible to provide a single "oidc_config"
297 # object with an "enabled: True" property.
298 oidc_config = config.get("oidc_config")
299 if oidc_config and oidc_config.get("enabled", False):
300 # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
301 # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
302 # above), so now we need to validate it.
303 validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
304 yield _parse_oidc_config_dict(oidc_config, ("oidc_config",))
305
306
307 def _parse_oidc_config_dict(
308 oidc_config: JsonDict, config_path: Tuple[str, ...]
309 ) -> "OidcProviderConfig":
310 """Take the configuration dict and parse it into an OidcProviderConfig
311
312 Raises:
313 ConfigError if the configuration is malformed.
314 """
315 ump_config = oidc_config.get("user_mapping_provider", {})
316 ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
317 ump_config.setdefault("config", {})
318
319 (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
320 ump_config, config_path + ("user_mapping_provider",)
321 )
322
323 # Ensure loaded user mapping module has defined all necessary methods
324 required_methods = [
325 "get_remote_user_id",
326 "map_user_attributes",
327 ]
328 missing_methods = [
329 method
330 for method in required_methods
331 if not hasattr(user_mapping_provider_class, method)
332 ]
333 if missing_methods:
334 raise ConfigError(
335 "Class %s is missing required "
336 "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
337 config_path + ("user_mapping_provider", "module"),
338 )
339
340 # MSC2858 will apply certain limits on what can be used as an IdP id, so let's
341 # enforce those limits now.
342 # TODO: factor out this stuff to a generic function
343 idp_id = oidc_config.get("idp_id", "oidc")
344
345 # TODO: update this validity check based on what MSC2858 decides.
346 valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
347
348 if any(c not in valid_idp_chars for c in idp_id):
349 raise ConfigError(
350 'idp_id may only contain a-z, 0-9, "-", ".", "_"',
351 config_path + ("idp_id",),
352 )
353
354 if idp_id[0] not in string.ascii_lowercase:
355 raise ConfigError(
356 "idp_id must start with a-z", config_path + ("idp_id",),
357 )
358
359 # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
360 # clashes with other mechs (such as SAML, CAS).
361 #
362 # We allow "oidc" as an exception so that people migrating from old-style
363 # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to
364 # a new-style "oidc_providers" entry without changing the idp_id for their provider
365 # (and thereby invalidating their user_external_ids data).
366
367 if idp_id != "oidc":
368 idp_id = "oidc-" + idp_id
369
370 # MSC2858 also specifies that the idp_icon must be a valid MXC uri
371 idp_icon = oidc_config.get("idp_icon")
372 if idp_icon is not None:
373 try:
374 parse_and_validate_mxc_uri(idp_icon)
375 except ValueError as e:
376 raise ConfigError(
377 "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
378 ) from e
379
380 return OidcProviderConfig(
381 idp_id=idp_id,
382 idp_name=oidc_config.get("idp_name", "OIDC"),
383 idp_icon=idp_icon,
384 discover=oidc_config.get("discover", True),
385 issuer=oidc_config["issuer"],
386 client_id=oidc_config["client_id"],
387 client_secret=oidc_config["client_secret"],
388 client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
389 scopes=oidc_config.get("scopes", ["openid"]),
390 authorization_endpoint=oidc_config.get("authorization_endpoint"),
391 token_endpoint=oidc_config.get("token_endpoint"),
392 userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
393 jwks_uri=oidc_config.get("jwks_uri"),
394 skip_verification=oidc_config.get("skip_verification", False),
395 user_profile_method=oidc_config.get("user_profile_method", "auto"),
396 allow_existing_users=oidc_config.get("allow_existing_users", False),
397 user_mapping_provider_class=user_mapping_provider_class,
398 user_mapping_provider_config=user_mapping_provider_config,
399 )
400
401
402 @attr.s(slots=True, frozen=True)
403 class OidcProviderConfig:
404 # a unique identifier for this identity provider. Used in the 'user_external_ids'
405 # table, as well as the query/path parameter used in the login protocol.
406 idp_id = attr.ib(type=str)
407
408 # user-facing name for this identity provider.
409 idp_name = attr.ib(type=str)
410
411 # Optional MXC URI for icon for this IdP.
412 idp_icon = attr.ib(type=Optional[str])
413
414 # whether the OIDC discovery mechanism is used to discover endpoints
415 discover = attr.ib(type=bool)
416
417 # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
418 # discover the provider's endpoints.
419 issuer = attr.ib(type=str)
420
421 # oauth2 client id to use
422 client_id = attr.ib(type=str)
423
424 # oauth2 client secret to use
425 client_secret = attr.ib(type=str)
426
427 # auth method to use when exchanging the token.
428 # Valid values are 'client_secret_basic', 'client_secret_post' and
429 # 'none'.
430 client_auth_method = attr.ib(type=str)
431
432 # list of scopes to request
433 scopes = attr.ib(type=Collection[str])
434
435 # the oauth2 authorization endpoint. Required if discovery is disabled.
436 authorization_endpoint = attr.ib(type=Optional[str])
437
438 # the oauth2 token endpoint. Required if discovery is disabled.
439 token_endpoint = attr.ib(type=Optional[str])
440
441 # the OIDC userinfo endpoint. Required if discovery is disabled and the
442 # "openid" scope is not requested.
443 userinfo_endpoint = attr.ib(type=Optional[str])
444
445 # URI where to fetch the JWKS. Required if discovery is disabled and the
446 # "openid" scope is used.
447 jwks_uri = attr.ib(type=Optional[str])
448
449 # Whether to skip metadata verification
450 skip_verification = attr.ib(type=bool)
451
452 # Whether to fetch the user profile from the userinfo endpoint. Valid
453 # values are: "auto" or "userinfo_endpoint".
454 user_profile_method = attr.ib(type=str)
455
456 # whether to allow a user logging in via OIDC to match a pre-existing account
457 # instead of failing
458 allow_existing_users = attr.ib(type=bool)
459
460 # the class of the user mapping provider
461 user_mapping_provider_class = attr.ib(type=Type)
462
463 # the config of the user mapping provider
464 user_mapping_provider_config = attr.ib()
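The `idp_id` handling above is worth calling out: every provider ID read from `oidc_providers` is prefixed with `oidc-` so it cannot clash with other SSO mechanisms, except the literal `oidc`, which is left alone so legacy `oidc_config` deployments keep their existing `user_external_ids` rows. A minimal sketch of just that rule; `prefix_idp_id` is an illustrative helper, not part of the config class:

```python
def prefix_idp_id(configured_idp_id: str) -> str:
    """Mirror the prefixing rule applied in _parse_oidc_config_dict."""
    if configured_idp_id != "oidc":
        return "oidc-" + configured_idp_id
    return configured_idp_id

assert prefix_idp_id("github") == "oidc-github"
assert prefix_idp_id("keycloak") == "oidc-keycloak"
# the legacy single-provider config keeps its historical id unchanged
assert prefix_idp_id("oidc") == "oidc"
```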
1313 # limitations under the License.
1414
1515 import os
16 from distutils.util import strtobool
1716
1817 import pkg_resources
1918
2019 from synapse.api.constants import RoomCreationPreset
2120 from synapse.config._base import Config, ConfigError
2221 from synapse.types import RoomAlias, UserID
23 from synapse.util.stringutils import random_string_with_symbols
22 from synapse.util.stringutils import random_string_with_symbols, strtobool
2423
2524
2625 class AccountValidityConfig(Config):
4948
5049 self.startup_job_max_delta = self.period * 10.0 / 100.0
5150
52 if self.renew_by_email_enabled:
53 if "public_baseurl" not in synapse_config:
54 raise ConfigError("Can't send renewal emails without 'public_baseurl'")
55
5651 template_dir = config.get("template_dir")
5752
5853 if not template_dir:
8580 section = "registration"
8681
8782 def read_config(self, config, **kwargs):
88 self.enable_registration = bool(
89 strtobool(str(config.get("enable_registration", False)))
83 self.enable_registration = strtobool(
84 str(config.get("enable_registration", False))
9085 )
9186 if "disable_registration" in config:
92 self.enable_registration = not bool(
93 strtobool(str(config["disable_registration"]))
87 self.enable_registration = not strtobool(
88 str(config["disable_registration"])
9489 )
9590
9691 self.account_validity = AccountValidityConfig(
109104 account_threepid_delegates = config.get("account_threepid_delegates") or {}
110105 self.account_threepid_delegate_email = account_threepid_delegates.get("email")
111106 self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
112 if self.account_threepid_delegate_msisdn and not self.public_baseurl:
113 raise ConfigError(
114 "The configuration option `public_baseurl` is required if "
115 "`account_threepid_delegate.msisdn` is set, such that "
116 "clients know where to submit validation tokens to. Please "
117 "configure `public_baseurl`."
118 )
119107
120108 self.default_identity_server = config.get("default_identity_server")
121109 self.allow_guest_access = config.get("allow_guest_access", False)
240228 # send an email to the account's email address with a renewal link. By
241229 # default, no such emails are sent.
242230 #
243 # If you enable this setting, you will also need to fill out the 'email' and
244 # 'public_baseurl' configuration sections.
231 # If you enable this setting, you will also need to fill out the 'email'
232 # configuration section. You should also check that 'public_baseurl' is set
233 # correctly.
245234 #
246235 #renew_at: 1w
247236
332321 # The identity server which we suggest that clients should use when users log
333322 # in on this server.
334323 #
335 # (By default, no suggestion is made, so it is left up to the client.
336 # This setting is ignored unless public_baseurl is also set.)
324 # (By default, no suggestion is made, so it is left up to the client.)
337325 #
338326 #default_identity_server: https://matrix.org
339327
357345 # Servers handling these requests must answer the `/requestToken` endpoints defined
358346 # by the Matrix Identity Service API specification:
359347 # https://matrix.org/docs/spec/identity_service/latest
360 #
361 # If a delegate is specified, the config option public_baseurl must also be filled out.
362348 #
363349 account_threepid_delegates:
364350 #email: https://example.com # Delegate email sending to example.com
188188 import saml2
189189
190190 public_baseurl = self.public_baseurl
191 if public_baseurl is None:
192 raise ConfigError("saml2_config requires a public_baseurl to be set")
193191
194192 if self.saml2_grandfathered_mxid_source_attribute:
195193 optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
2525 from netaddr import IPSet
2626
2727 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
28 from synapse.http.endpoint import parse_and_validate_server_name
28 from synapse.util.stringutils import parse_and_validate_server_name
2929
3030 from ._base import Config, ConfigError
3131
160160 self.print_pidfile = config.get("print_pidfile")
161161 self.user_agent_suffix = config.get("user_agent_suffix")
162162 self.use_frozen_dicts = config.get("use_frozen_dicts", False)
163 self.public_baseurl = config.get("public_baseurl")
163 self.public_baseurl = config.get("public_baseurl") or "https://%s/" % (
164 self.server_name,
165 )
166 if self.public_baseurl[-1] != "/":
167 self.public_baseurl += "/"
164168
165169 # Whether to enable user presence.
166170 self.use_presence = config.get("use_presence", True)
316320 # Always blacklist 0.0.0.0, ::
317321 self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
318322
319 if self.public_baseurl is not None:
320 if self.public_baseurl[-1] != "/":
321 self.public_baseurl += "/"
322323 self.start_pushers = config.get("start_pushers", True)
323324
324325 # (undocumented) option for torturing the worker-mode replication a bit,
739740 #
740741 #web_client_location: https://riot.example.com/
741742
742 # The public-facing base URL that clients use to access this HS
743 # (not including _matrix/...). This is the same URL a user would
744 # enter into the 'custom HS URL' field on their client. If you
745 # use synapse with a reverse proxy, this should be the URL to reach
746 # synapse via the proxy.
743 # The public-facing base URL that clients use to access this Homeserver (not
744 # including _matrix/...). This is the same URL a user might enter into the
745 # 'Custom Homeserver URL' field on their client. If you use Synapse with a
746 # reverse proxy, this should be the URL to reach Synapse via the proxy.
747 # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
748 # 'listeners' below).
749 #
750 # If this is left unset, it defaults to 'https://<server_name>/'. (Note that
751 # that will not work unless you configure Synapse or a reverse-proxy to listen
752 # on port 443.)
747753 #
748754 #public_baseurl: https://example.com/
749755
3030
3131 # Read templates from disk
3232 (
33 self.sso_login_idp_picker_template,
3334 self.sso_redirect_confirm_template,
3435 self.sso_auth_confirm_template,
3536 self.sso_error_template,
3637 sso_account_deactivated_template,
3738 sso_auth_success_template,
39 self.sso_auth_bad_user_template,
3840 ) = self.read_templates(
3941 [
42 "sso_login_idp_picker.html",
4043 "sso_redirect_confirm.html",
4144 "sso_auth_confirm.html",
4245 "sso_error.html",
4346 "sso_account_deactivated.html",
4447 "sso_auth_success.html",
48 "sso_auth_bad_user.html",
4549 ],
4650 template_dir,
4751 )
5963 # gracefully to the client). This would make it pointless to ask the user for
6064 # confirmation, since the URL the confirmation page would be showing wouldn't be
6165 # the client's.
62 # public_baseurl is an optional setting, so we only add the fallback's URL to the
63 # list if it's provided (because we can't figure out what that URL is otherwise).
64 if self.public_baseurl:
65 login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
66 self.sso_client_whitelist.append(login_fallback_url)
66 login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
67 self.sso_client_whitelist.append(login_fallback_url)
6768
6869 def generate_config_section(self, **kwargs):
6970 return """\
8182 # phishing attacks from evil.site. To avoid this, include a slash after the
8283 # hostname: "https://my.client/".
8384 #
84 # If public_baseurl is set, then the login fallback page (used by clients
85 # that don't natively support the required login flows) is whitelisted in
86 # addition to any URLs in this list.
85 # The login fallback page (used by clients that don't natively support the
86 # required login flows) is automatically whitelisted in addition to any URLs
87 # in this list.
8788 #
8889 # By default, this list is empty.
8990 #
9697 # directory, default templates from within the Synapse package will be used.
9798 #
9899 # Synapse will look for the following templates in this directory:
100 #
101 # * HTML page to prompt the user to choose an Identity Provider during
102 # login: 'sso_login_idp_picker.html'.
103 #
104 # This is only used if multiple SSO Identity Providers are configured.
105 #
106 # When rendering, this template is given the following variables:
107 # * redirect_url: the URL that the user will be redirected to after
108 # login. Needs manual escaping (see
109 # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
110 #
111 # * server_name: the homeserver's name.
112 #
113 # * providers: a list of available Identity Providers. Each element is
114 # an object with the following attributes:
115 # * idp_id: unique identifier for the IdP
116 # * idp_name: user-facing name for the IdP
117 #
118 # The rendered HTML page should contain a form which submits its results
119 # back as a GET request, with the following query parameters:
120 #
121 # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
122 # to the template)
123 #
124 # * idp: the 'idp_id' of the chosen IDP.
99125 #
100126 # * HTML page for a confirmation step before redirecting back to the client
101127 # with the login token: 'sso_redirect_confirm.html'.
132158 #
133159 # This template has no additional variables.
134160 #
161 # * HTML page shown after a user-interactive authentication session which
162 # does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
163 #
164 # When rendering, this template is given the following variables:
165 # * server_name: the homeserver's name.
166 # * user_id_to_verify: the MXID of the user that we are trying to
167 # validate.
168 #
135169 # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
136170 # attempts to login: 'sso_account_deactivated.html'.
137171 #
5252 default=["master"], type=List[str], converter=_instance_to_list_converter
5353 )
5454 typing = attr.ib(default="master", type=str)
55 to_device = attr.ib(
56 default=["master"], type=List[str], converter=_instance_to_list_converter,
57 )
58 account_data = attr.ib(
59 default=["master"], type=List[str], converter=_instance_to_list_converter,
60 )
61 receipts = attr.ib(
62 default=["master"], type=List[str], converter=_instance_to_list_converter,
63 )
5564
5665
5766 class WorkerConfig(Config):
123132
124133 # Check that the configured writers for events and typing also appears in
125134 # `instance_map`.
126 for stream in ("events", "typing"):
135 for stream in ("events", "typing", "to_device", "account_data", "receipts"):
127136 instances = _instance_to_list_converter(getattr(self.writers, stream))
128137 for instance in instances:
129138 if instance != "master" and instance not in self.instance_map:
132141 % (instance, stream)
133142 )
134143
144 if len(self.writers.to_device) != 1:
145 raise ConfigError(
146 "Must only specify one instance to handle `to_device` messages."
147 )
148
149 if len(self.writers.account_data) != 1:
150 raise ConfigError(
151 "Must only specify one instance to handle `account_data` messages."
152 )
153
154 if len(self.writers.receipts) != 1:
155 raise ConfigError(
156 "Must only specify one instance to handle `receipts` messages."
157 )
158
135159 self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
136160
137161 # Whether this worker should run background tasks or not.
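
A minimal sketch of the stream-writer validation added above, using plain dicts in place of Synapse's attrs-based config classes: every non-master writer must appear in `instance_map`, and the `to_device`, `account_data` and `receipts` streams each accept exactly one writer. The function name and the ValueError are illustrative assumptions; Synapse raises ConfigError from its own config plumbing.

    from typing import Dict, List

    def check_writers(writers: Dict[str, List[str]], instance_map: Dict[str, dict]) -> None:
        # Every named writer other than "master" must have an entry in instance_map.
        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
            for instance in writers.get(stream, ["master"]):
                if instance != "master" and instance not in instance_map:
                    raise ValueError(
                        "Instance %r is configured to write %s but is not in `instance_map`"
                        % (instance, stream)
                    )
        # These streams do not (yet) support sharding across multiple writers.
        for stream in ("to_device", "account_data", "receipts"):
            if len(writers.get(stream, ["master"])) != 1:
                raise ValueError("Must only specify one instance to handle `%s` messages." % (stream,))

    check_writers({"receipts": ["worker1"]}, {"worker1": {"host": "localhost", "port": 8034}})
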
1616
1717 import abc
1818 import os
19 from distutils.util import strtobool
2019 from typing import Dict, Optional, Tuple, Type
2120
2221 from unpaddedbase64 import encode_base64
2524 from synapse.types import JsonDict, RoomStreamToken
2625 from synapse.util.caches import intern_dict
2726 from synapse.util.frozenutils import freeze
27 from synapse.util.stringutils import strtobool
2828
2929 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
3030 # bugs where we accidentally share e.g. signature dicts. However, converting a
3333 # NOTE: This is overridden by the configuration by the Synapse worker apps, but
3434 # for the sake of tests, it is set here while it cannot be configured on the
3535 # homeserver object itself.
36
3637 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
3738
3839
7878 "state_key",
7979 "depth",
8080 "prev_events",
81 "prev_state",
8281 "auth_events",
8382 "origin",
8483 "origin_server_ts",
85 "membership",
8684 ]
85
86 # Room versions from before MSC2176 had additional allowed keys.
87 if not room_version.msc2176_redaction_rules:
88 allowed_keys.extend(["prev_state", "membership"])
8789
8890 event_type = event_dict["type"]
8991
9799 if event_type == EventTypes.Member:
98100 add_fields("membership")
99101 elif event_type == EventTypes.Create:
102 # MSC2176 rules state that create events cannot be redacted.
103 if room_version.msc2176_redaction_rules:
104 return event_dict
105
100106 add_fields("creator")
101107 elif event_type == EventTypes.JoinRules:
102108 add_fields("join_rule")
111117 "kick",
112118 "redact",
113119 )
120
121 if room_version.msc2176_redaction_rules:
122 add_fields("invite")
123
114124 elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
115125 add_fields("aliases")
116126 elif event_type == EventTypes.RoomHistoryVisibility:
117127 add_fields("history_visibility")
128 elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
129 add_fields("redacts")
118130
119131 allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
120132
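
As a worked illustration of the MSC2176 branches above (not the full prune logic): `prev_state` and `membership` are no longer kept as top-level keys, create events are returned untouched, power-level events additionally keep `invite`, and redaction events keep `redacts`. The toy model below covers only the power-levels case; the baseline key set is taken from the Matrix spec's redaction rules rather than from this hunk.

    # Toy model of the power-levels branch only, showing the MSC2176 difference.
    def surviving_power_level_keys(msc2176_redaction_rules: bool) -> set:
        keys = {"ban", "events", "events_default", "kick", "redact",
                "state_default", "users", "users_default"}
        if msc2176_redaction_rules:
            keys.add("invite")
        return keys

    assert "invite" in surviving_power_level_keys(True)
    assert "invite" not in surviving_power_level_keys(False)
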
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616 import logging
17 import random
1718 from typing import (
1819 TYPE_CHECKING,
1920 Any,
4748 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
4849 from synapse.federation.persistence import TransactionActions
4950 from synapse.federation.units import Edu, Transaction
50 from synapse.http.endpoint import parse_server_name
5151 from synapse.http.servlet import assert_params_in_dict
5252 from synapse.logging.context import (
5353 make_deferred_yieldable,
6464 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
6565 from synapse.util.async_helpers import Linearizer, concurrently_execute
6666 from synapse.util.caches.response_cache import ResponseCache
67 from synapse.util.stringutils import parse_server_name
6768
6869 if TYPE_CHECKING:
6970 from synapse.server import HomeServer
859860 ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
860861 self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
861862
862 # Map from type to instance name that we should route EDU handling to.
863 self._edu_type_to_instance = {} # type: Dict[str, str]
863 # Map from type to instance names that we should route EDU handling to.
864 # We randomly choose one instance from the list to route to for each new
865 # EDU received.
866 self._edu_type_to_instance = {} # type: Dict[str, List[str]]
864867
865868 def register_edu_handler(
866869 self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
904907 def register_instance_for_edu(self, edu_type: str, instance_name: str):
905908 """Register that the EDU handler is on a different instance than master.
906909 """
907 self._edu_type_to_instance[edu_type] = instance_name
910 self._edu_type_to_instance[edu_type] = [instance_name]
911
912 def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
913 """Register that the EDU handler is on multiple instances.
914 """
915 self._edu_type_to_instance[edu_type] = instance_names
908916
909917 async def on_edu(self, edu_type: str, origin: str, content: dict):
910918 if not self.config.use_presence and edu_type == "m.presence":
923931 return
924932
925933 # Check if we can route it somewhere else that isn't us
926 route_to = self._edu_type_to_instance.get(edu_type, "master")
927 if route_to != self._instance_name:
934 instances = self._edu_type_to_instance.get(edu_type, ["master"])
935 if self._instance_name not in instances:
936 # Pick an instance randomly so that we don't overload one.
937 route_to = random.choice(instances)
938
928939 try:
929940 await self._send_edu(
930941 instance_name=route_to,
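
A self-contained sketch of the routing change above: an EDU type now maps to a list of handling instances, and when the local instance is not among them one is picked at random to spread load. The instance names and EDU type below are placeholders.

    import random
    from typing import Dict, List

    edu_type_to_instance = {}  # type: Dict[str, List[str]]
    edu_type_to_instance["m.receipt"] = ["receipts1", "receipts2"]

    def pick_route(edu_type: str, local_instance: str) -> str:
        instances = edu_type_to_instance.get(edu_type, ["master"])
        if local_instance in instances:
            return local_instance  # handle the EDU locally
        # Pick an instance randomly so that we don't overload one.
        return random.choice(instances)

    print(pick_route("m.receipt", "federation_reader1"))
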
2727 FEDERATION_V1_PREFIX,
2828 FEDERATION_V2_PREFIX,
2929 )
30 from synapse.http.endpoint import parse_and_validate_server_name
3130 from synapse.http.server import JsonResource
3231 from synapse.http.servlet import (
3332 parse_boolean_from_args,
4443 )
4544 from synapse.server import HomeServer
4645 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
46 from synapse.util.stringutils import parse_and_validate_server_name
4747 from synapse.util.versionstring import get_version_string
4848
4949 logger = logging.getLogger(__name__)
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1112 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
15 import random
1416 from typing import TYPE_CHECKING, List, Tuple
1517
18 from synapse.replication.http.account_data import (
19 ReplicationAddTagRestServlet,
20 ReplicationRemoveTagRestServlet,
21 ReplicationRoomAccountDataRestServlet,
22 ReplicationUserAccountDataRestServlet,
23 )
1624 from synapse.types import JsonDict, UserID
1725
1826 if TYPE_CHECKING:
1927 from synapse.app.homeserver import HomeServer
28
29
30 class AccountDataHandler:
31 def __init__(self, hs: "HomeServer"):
32 self._store = hs.get_datastore()
33 self._instance_name = hs.get_instance_name()
34 self._notifier = hs.get_notifier()
35
36 self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
37 self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
38 self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
39 self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
40 self._account_data_writers = hs.config.worker.writers.account_data
41
42 async def add_account_data_to_room(
43 self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
44 ) -> int:
45 """Add some account_data to a room for a user.
46
47 Args:
48 user_id: The user to add the account data for.
49 room_id: The room to add the account data to.
50 account_data_type: The type of account_data to add.
51 content: A json object to associate with the account data.
52
53 Returns:
54 The maximum stream ID.
55 """
56 if self._instance_name in self._account_data_writers:
57 max_stream_id = await self._store.add_account_data_to_room(
58 user_id, room_id, account_data_type, content
59 )
60
61 self._notifier.on_new_event(
62 "account_data_key", max_stream_id, users=[user_id]
63 )
64
65 return max_stream_id
66 else:
67 response = await self._room_data_client(
68 instance_name=random.choice(self._account_data_writers),
69 user_id=user_id,
70 room_id=room_id,
71 account_data_type=account_data_type,
72 content=content,
73 )
74 return response["max_stream_id"]
75
76 async def add_account_data_for_user(
77 self, user_id: str, account_data_type: str, content: JsonDict
78 ) -> int:
79 """Add some account_data to a room for a user.
80
81 Args:
82 user_id: The user to add the account data for.
83 account_data_type: The type of account_data to add.
84 content: A json object to associate with the account data.
85
86 Returns:
87 The maximum stream ID.
88 """
89
90 if self._instance_name in self._account_data_writers:
91 max_stream_id = await self._store.add_account_data_for_user(
92 user_id, account_data_type, content
93 )
94
95 self._notifier.on_new_event(
96 "account_data_key", max_stream_id, users=[user_id]
97 )
98 return max_stream_id
99 else:
100 response = await self._user_data_client(
101 instance_name=random.choice(self._account_data_writers),
102 user_id=user_id,
103 account_data_type=account_data_type,
104 content=content,
105 )
106 return response["max_stream_id"]
107
108 async def add_tag_to_room(
109 self, user_id: str, room_id: str, tag: str, content: JsonDict
110 ) -> int:
111 """Add a tag to a room for a user.
112
113 Args:
114 user_id: The user to add a tag for.
115 room_id: The room to add a tag for.
116 tag: The tag name to add.
117 content: A json object to associate with the tag.
118
119 Returns:
120 The next account data ID.
121 """
122 if self._instance_name in self._account_data_writers:
123 max_stream_id = await self._store.add_tag_to_room(
124 user_id, room_id, tag, content
125 )
126
127 self._notifier.on_new_event(
128 "account_data_key", max_stream_id, users=[user_id]
129 )
130 return max_stream_id
131 else:
132 response = await self._add_tag_client(
133 instance_name=random.choice(self._account_data_writers),
134 user_id=user_id,
135 room_id=room_id,
136 tag=tag,
137 content=content,
138 )
139 return response["max_stream_id"]
140
141 async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
142 """Remove a tag from a room for a user.
143
144 Returns:
145 The next account data ID.
146 """
147 if self._instance_name in self._account_data_writers:
148 max_stream_id = await self._store.remove_tag_from_room(
149 user_id, room_id, tag
150 )
151
152 self._notifier.on_new_event(
153 "account_data_key", max_stream_id, users=[user_id]
154 )
155 return max_stream_id
156 else:
157 response = await self._remove_tag_client(
158 instance_name=random.choice(self._account_data_writers),
159 user_id=user_id,
160 room_id=room_id,
161 tag=tag,
162 )
163 return response["max_stream_id"]
20164
21165
22166 class AccountDataEventSource:
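
The new AccountDataHandler methods all follow the same write-or-forward pattern: if this instance is one of the configured `account_data` writers, persist locally and poke the notifier; otherwise call the replication endpoint on a randomly chosen writer. A schematic version is sketched below, with `store`, `notifier` and `replication_client` standing in for the real Synapse objects; it is illustrative, not the handler itself.

    import random

    async def add_account_data(instance_name, writers, store, notifier,
                               replication_client, user_id, data_type, content):
        if instance_name in writers:
            # We are a writer: persist and wake up /sync streams directly.
            max_stream_id = await store.add_account_data_for_user(user_id, data_type, content)
            notifier.on_new_event("account_data_key", max_stream_id, users=[user_id])
            return max_stream_id
        # Otherwise proxy the write to one of the configured writers.
        response = await replication_client(
            instance_name=random.choice(writers),
            user_id=user_id,
            account_data_type=data_type,
            content=content,
        )
        return response["max_stream_id"]
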
4848 UserDeactivatedError,
4949 )
5050 from synapse.api.ratelimiting import Ratelimiter
51 from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
51 from synapse.handlers._base import BaseHandler
52 from synapse.handlers.ui_auth import (
53 INTERACTIVE_AUTH_CHECKERS,
54 UIAuthSessionDataConstants,
55 )
5256 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
57 from synapse.http import get_request_user_agent
5358 from synapse.http.server import finish_request, respond_with_html
5459 from synapse.http.site import SynapseRequest
5560 from synapse.logging.context import defer_to_thread
6065 from synapse.util.async_helpers import maybe_awaitable
6166 from synapse.util.msisdn import phone_number_to_msisdn
6267 from synapse.util.threepids import canonicalise_email
63
64 from ._base import BaseHandler
6568
6669 if TYPE_CHECKING:
6770 from synapse.app.homeserver import HomeServer
259262 # authenticating for an operation to occur on their account.
260263 self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
261264
262 # The following template is shown after a successful user interactive
263 # authentication session. It tells the user they can close the window.
264 self._sso_auth_success_template = hs.config.sso_auth_success_template
265
266265 # The following template is shown during the SSO authentication process if
267266 # the account is deactivated.
268267 self._sso_account_deactivated_template = (
283282 requester: Requester,
284283 request: SynapseRequest,
285284 request_body: Dict[str, Any],
286 clientip: str,
287285 description: str,
288286 ) -> Tuple[dict, Optional[str]]:
289287 """
299297 request: The request sent by the client.
300298
301299 request_body: The body of the request sent by the client
302
303 clientip: The IP address of the client.
304300
305301 description: A human readable string to be displayed to the user that
306302 describes the operation happening on their account.
337333 request_body.pop("auth", None)
338334 return request_body, None
339335
340 user_id = requester.user.to_string()
336 requester_user_id = requester.user.to_string()
341337
342338 # Check if we should be ratelimited due to too many previous failed attempts
343 self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
339 self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
344340
345341 # build a list of supported flows
346342 supported_ui_auth_types = await self._get_available_ui_auth_types(
348344 )
349345 flows = [[login_type] for login_type in supported_ui_auth_types]
350346
347 def get_new_session_data() -> JsonDict:
348 return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
349
351350 try:
352351 result, params, session_id = await self.check_ui_auth(
353 flows, request, request_body, clientip, description
352 flows, request, request_body, description, get_new_session_data,
354353 )
355354 except LoginError:
356355 # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
357 self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
356 self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
358357 raise
359358
360359 # find the completed login type
362361 if login_type not in result:
363362 continue
364363
365 user_id = result[login_type]
364 validated_user_id = result[login_type]
366365 break
367366 else:
368367 # this can't happen
369368 raise Exception("check_auth returned True but no successful login type")
370369
371370 # check that the UI auth matched the access token
372 if user_id != requester.user.to_string():
371 if validated_user_id != requester_user_id:
373372 raise AuthError(403, "Invalid auth")
374373
375374 # Note that the access token has been validated.
401400
402401 # if sso is enabled, allow the user to log in via SSO iff they have a mapping
403402 # from sso to mxid.
404 if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
405 if await self.store.get_external_ids_by_user(user.to_string()):
406 ui_auth_types.add(LoginType.SSO)
407
408 # Our CAS impl does not (yet) correctly register users in user_external_ids,
409 # so always offer that if it's available.
410 if self.hs.config.cas.cas_enabled:
403 if await self.hs.get_sso_handler().get_identity_providers_for_user(
404 user.to_string()
405 ):
411406 ui_auth_types.add(LoginType.SSO)
412407
413408 return ui_auth_types
425420 flows: List[List[str]],
426421 request: SynapseRequest,
427422 clientdict: Dict[str, Any],
428 clientip: str,
429423 description: str,
424 get_new_session_data: Optional[Callable[[], JsonDict]] = None,
430425 ) -> Tuple[dict, dict, str]:
431426 """
432427 Takes a dictionary sent by the client in the login / registration
447442 clientdict: The dictionary from the client root level, not the
448443 'auth' key: this method prompts for auth if none is sent.
449444
450 clientip: The IP address of the client.
451
452445 description: A human readable string to be displayed to the user that
453446 describes the operation happening on their account.
447
448 get_new_session_data:
449 An optional callback which will be called when starting a new session.
450 It should return data to be stored as part of the session.
451
452 The keys of the returned data should be entries in
453 UIAuthSessionDataConstants.
454454
455455 Returns:
456456 A tuple of (creds, params, session_id).
479479
480480 # If there's no session ID, create a new session.
481481 if not sid:
482 new_session_data = get_new_session_data() if get_new_session_data else {}
483
482484 session = await self.store.create_ui_auth_session(
483485 clientdict, uri, method, description
484486 )
487
488 for k, v in new_session_data.items():
489 await self.set_session_data(session.session_id, k, v)
485490
486491 else:
487492 try:
538543 # authentication flow.
539544 await self.store.set_ui_auth_clientdict(sid, clientdict)
540545
541 user_agent = request.get_user_agent("")
546 user_agent = get_request_user_agent(request)
547 clientip = request.getClientIP()
542548
543549 await self.store.add_user_agent_ip_to_ui_auth_session(
544550 session.session_id, user_agent, clientip
643649
644650 Args:
645651 session_id: The ID of this session as returned from check_auth
646 key: The key to store the data under
652 key: The key to store the data under. An entry from
653 UIAuthSessionDataConstants.
647654 value: The data to store
648655 """
649656 try:
659666
660667 Args:
661668 session_id: The ID of this session as returned from check_auth
662 key: The key to store the data under
669 key: The key the data was stored under. An entry from
670 UIAuthSessionDataConstants.
663671 default: Value to return if the key has not been set
664672 """
665673 try:
13331341 else:
13341342 return False
13351343
1336 async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
1344 async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
13371345 """
13381346 Get the HTML for the SSO redirect confirmation page.
13391347
13401348 Args:
1341 redirect_url: The URL to redirect to the SSO provider.
1349 request: The incoming HTTP request
13421350 session_id: The user interactive authentication session ID.
13431351
13441352 Returns:
13481356 session = await self.store.get_ui_auth_session(session_id)
13491357 except StoreError:
13501358 raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
1359
1360 user_id_to_verify = await self.get_session_data(
1361 session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
1362 ) # type: str
1363
1364 idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
1365 user_id_to_verify
1366 )
1367
1368 if not idps:
1369 # we checked that the user had some remote identities before offering an SSO
1370 # flow, so either they have been deleted or the client has requested SSO despite
1371 # it not being offered.
1372 raise SynapseError(400, "User has no SSO identities")
1373
1374 # for now, just pick one
1375 idp_id, sso_auth_provider = next(iter(idps.items()))
1376 if len(idps) > 1:
1377 logger.warning(
1378 "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
1379 "picking %r",
1380 user_id_to_verify,
1381 idp_id,
1382 )
1383
1384 redirect_url = await sso_auth_provider.handle_redirect_request(
1385 request, None, session_id
1386 )
1387
13511388 return self._sso_auth_confirm_template.render(
13521389 description=session.description, redirect_url=redirect_url,
13531390 )
1354
1355 async def complete_sso_ui_auth(
1356 self, registered_user_id: str, session_id: str, request: Request,
1357 ):
1358 """Having figured out a mxid for this user, complete the HTTP request
1359
1360 Args:
1361 registered_user_id: The registered user ID to complete SSO login for.
1362 session_id: The ID of the user-interactive auth session.
1363 request: The request to complete.
1364 """
1365 # Mark the stage of the authentication as successful.
1366 # Save the user who authenticated with SSO, this will be used to ensure
1367 # that the account be modified is also the person who logged in.
1368 await self.store.mark_ui_auth_stage_complete(
1369 session_id, LoginType.SSO, registered_user_id
1370 )
1371
1372 # Render the HTML and return.
1373 html = self._sso_auth_success_template
1374 respond_with_html(request, 200, html)
13751391
13761392 async def complete_sso_login(
13771393 self,
14871503 @staticmethod
14881504 def add_query_param_to_url(url: str, param_name: str, param: Any):
14891505 url_parts = list(urllib.parse.urlparse(url))
1490 query = dict(urllib.parse.parse_qsl(url_parts[4]))
1491 query.update({param_name: param})
1506 query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
1507 query.append((param_name, param))
14921508 url_parts[4] = urllib.parse.urlencode(query)
14931509 return urllib.parse.urlunparse(url_parts)
14941510
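
The add_query_param_to_url change above switches from collapsing the query into a dict to keeping it as a list of pairs with keep_blank_values=True, so empty and repeated parameters survive the round trip. A standalone demonstration follows; the URL and the loginToken value are made up.

    import urllib.parse

    url_parts = list(urllib.parse.urlparse("https://client.example/?next=&flow=sso"))

    # Keep the query as ordered pairs; empty values such as "next=" are preserved.
    query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
    query.append(("loginToken", "abc123"))
    url_parts[4] = urllib.parse.urlencode(query)

    print(urllib.parse.urlunparse(url_parts))
    # https://client.example/?next=&flow=sso&loginToken=abc123
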
7474 self._http_client = hs.get_proxied_http_client()
7575
7676 # identifier for the external_ids table
77 self._auth_provider_id = "cas"
77 self.idp_id = "cas"
78
79 # user-facing name of this auth provider
80 self.idp_name = "CAS"
81
82 # we do not currently support icons for CAS auth, but this is required by
83 # the SsoIdentityProvider protocol type.
84 self.idp_icon = None
7885
7986 self._sso_handler = hs.get_sso_handler()
87
88 self._sso_handler.register_identity_provider(self)
8089
8190 def _build_service_param(self, args: Dict[str, str]) -> str:
8291 """
104113 Args:
105114 ticket: The CAS ticket from the client.
106115 service_args: Additional arguments to include in the service URL.
107 Should be the same as those passed to `get_redirect_url`.
116 Should be the same as those passed to `handle_redirect_request`.
108117
109118 Raises:
110119 CasError: If there's an error parsing the CAS response.
183192
184193 return CasResponse(user, attributes)
185194
186 def get_redirect_url(self, service_args: Dict[str, str]) -> str:
187 """
188 Generates a URL for the CAS server where the client should be redirected.
189
190 Args:
191 service_args: Additional arguments to include in the final redirect URL.
195 async def handle_redirect_request(
196 self,
197 request: SynapseRequest,
198 client_redirect_url: Optional[bytes],
199 ui_auth_session_id: Optional[str] = None,
200 ) -> str:
201 """Generates a URL for the CAS server where the client should be redirected.
202
203 Args:
204 request: the incoming HTTP request
205 client_redirect_url: the URL that we should redirect the
206 client to after login (or None for UI Auth).
207 ui_auth_session_id: The session ID of the ongoing UI Auth (or
208 None if this is a login).
192209
193210 Returns:
194 The URL to redirect the client to.
195 """
211 URL to redirect to
212 """
213
214 if ui_auth_session_id:
215 service_args = {"session": ui_auth_session_id}
216 else:
217 assert client_redirect_url
218 service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
219
196220 args = urllib.parse.urlencode(
197221 {"service": self._build_service_param(service_args)}
198222 )
274298 # first check if we're doing a UIA
275299 if session:
276300 return await self._sso_handler.complete_sso_ui_auth_request(
277 self._auth_provider_id, cas_response.username, session, request,
301 self.idp_id, cas_response.username, session, request,
278302 )
279303
280304 # otherwise, we're handling a login request.
374398 return None
375399
376400 await self._sso_handler.complete_sso_login_request(
377 self._auth_provider_id,
401 self.idp_id,
378402 cas_response.username,
379403 request,
380404 client_redirect_url,
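
The CAS handler now describes itself as an SSO identity provider (idp_id, idp_name, idp_icon) and registers with the SSO handler. The rough sketch below shows that shape using only the attributes and calls visible above; the class name, idp values and redirect URL are placeholders, not part of Synapse.

    from typing import Optional

    class ExampleIdentityProvider:
        def __init__(self, sso_handler):
            self.idp_id = "example"          # identifier for the external_ids table
            self.idp_name = "Example SSO"    # user-facing name of this auth provider
            self.idp_icon = None             # optional MXC URI for an icon
            sso_handler.register_identity_provider(self)

        async def handle_redirect_request(
            self,
            request,
            client_redirect_url: Optional[bytes],
            ui_auth_session_id: Optional[str] = None,
        ) -> str:
            # Return the URL of the external IdP to send the user to
            # (for a login, or for a UI Auth session).
            return "https://idp.example.com/login"
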
1717
1818 from synapse.api.errors import SynapseError
1919 from synapse.metrics.background_process_metrics import run_as_background_process
20 from synapse.types import UserID, create_requester
20 from synapse.types import Requester, UserID, create_requester
2121
2222 from ._base import BaseHandler
2323
3737 self._device_handler = hs.get_device_handler()
3838 self._room_member_handler = hs.get_room_member_handler()
3939 self._identity_handler = hs.get_identity_handler()
40 self._profile_handler = hs.get_profile_handler()
4041 self.user_directory_handler = hs.get_user_directory_handler()
4142 self._server_name = hs.hostname
4243
5152 self._account_validity_enabled = hs.config.account_validity.enabled
5253
5354 async def deactivate_account(
54 self, user_id: str, erase_data: bool, id_server: Optional[str] = None
55 self,
56 user_id: str,
57 erase_data: bool,
58 requester: Requester,
59 id_server: Optional[str] = None,
60 by_admin: bool = False,
5561 ) -> bool:
5662 """Deactivate a user's account
5763
5864 Args:
5965 user_id: ID of user to be deactivated
6066 erase_data: whether to GDPR-erase the user's data
67 requester: The user attempting to make this change.
6168 id_server: Use the given identity server when unbinding
6269 any threepids. If None then will attempt to unbind using the
6370 identity server specified when binding (if known).
71 by_admin: Whether this change was made by an administrator.
6472
6573 Returns:
6674 True if identity server supports removing threepids, otherwise False.
120128
121129 # Mark the user as erased, if they asked for that
122130 if erase_data:
131 user = UserID.from_string(user_id)
132 # Remove avatar URL from this user
133 await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
134 # Remove displayname from this user
135 await self._profile_handler.set_displayname(user, requester, "", by_admin)
136
123137 logger.info("Marking %s as erased", user_id)
124138 await self.store.mark_user_erased(user_id)
125139
2323 set_tag,
2424 start_active_span,
2525 )
26 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
2627 from synapse.types import JsonDict, UserID, get_domain_from_id
2728 from synapse.util import json_encoder
2829 from synapse.util.stringutils import random_string
4344 self.store = hs.get_datastore()
4445 self.notifier = hs.get_notifier()
4546 self.is_mine = hs.is_mine
46 self.federation = hs.get_federation_sender()
47
48 hs.get_federation_registry().register_edu_handler(
49 "m.direct_to_device", self.on_direct_to_device_edu
50 )
51
52 self._device_list_updater = hs.get_device_handler().device_list_updater
47
48 # We only need to poke the federation sender explicitly if it's on the
49 # same instance. Other federation sender instances will get notified by
50 # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
51 # in the to-device replication stream.
52 self.federation_sender = None
53 if hs.should_send_federation():
54 self.federation_sender = hs.get_federation_sender()
55
56 # If we can handle the to-device EDUs, we do so; otherwise we route them
57 # to the appropriate worker.
58 if hs.get_instance_name() in hs.config.worker.writers.to_device:
59 hs.get_federation_registry().register_edu_handler(
60 "m.direct_to_device", self.on_direct_to_device_edu
61 )
62 else:
63 hs.get_federation_registry().register_instances_for_edu(
64 "m.direct_to_device", hs.config.worker.writers.to_device,
65 )
66
67 # The handler to call when we think a user's device list might be out of
68 # sync. We do all device list resyncing on the master instance, so if
69 # we're on a worker we hit the device resync replication API.
70 if hs.config.worker.worker_app is None:
71 self._user_device_resync = (
72 hs.get_device_handler().device_list_updater.user_device_resync
73 )
74 else:
75 self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
76 hs
77 )
5378
5479 async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
5580 local_messages = {}
137162 await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
138163
139164 # Immediately attempt a resync in the background
140 run_in_background(
141 self._device_list_updater.user_device_resync, sender_user_id
142 )
165 run_in_background(self._user_device_resync, user_id=sender_user_id)
143166
144167 async def send_device_message(
145168 self,
194217 )
195218
196219 log_kv({"remote_messages": remote_messages})
197 for destination in remote_messages.keys():
198 # Enqueue a new federation transaction to send the new
199 # device messages to each remote destination.
200 self.federation.send_device_messages(destination)
220 if self.federation_sender:
221 for destination in remote_messages.keys():
222 # Enqueue a new federation transaction to send the new
223 # device messages to each remote destination.
224 self.federation_sender.send_device_messages(destination)
474474 raise e.to_synapse_error()
475475 except RequestTimedOutError:
476476 raise SynapseError(500, "Timed out contacting identity server")
477
478 assert self.hs.config.public_baseurl
479477
480478 # we need to tell the client to send the token back to us, since it doesn't
481479 # otherwise know where to send it, so add submit_url response parameter
1313 # limitations under the License.
1414 import inspect
1515 import logging
16 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
16 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
1717 from urllib.parse import urlencode
1818
1919 import attr
3434 from twisted.web.client import readBody
3535
3636 from synapse.config import ConfigError
37 from synapse.handlers._base import BaseHandler
37 from synapse.config.oidc_config import OidcProviderConfig
3838 from synapse.handlers.sso import MappingException, UserAttributes
3939 from synapse.http.site import SynapseRequest
4040 from synapse.logging.context import make_deferred_yieldable
7070 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
7171
7272
73 class OidcHandler:
74 """Handles requests related to the OpenID Connect login flow.
75 """
76
77 def __init__(self, hs: "HomeServer"):
78 self._sso_handler = hs.get_sso_handler()
79
80 provider_confs = hs.config.oidc.oidc_providers
81 # we should not have been instantiated if there is no configured provider.
82 assert provider_confs
83
84 self._token_generator = OidcSessionTokenGenerator(hs)
85 self._providers = {
86 p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
87 } # type: Dict[str, OidcProvider]
88
89 async def load_metadata(self) -> None:
90 """Validate the config and load the metadata from the remote endpoint.
91
92 Called at startup to ensure we have everything we need.
93 """
94 for idp_id, p in self._providers.items():
95 try:
96 await p.load_metadata()
97 await p.load_jwks()
98 except Exception as e:
99 raise Exception(
100 "Error while initialising OIDC provider %r" % (idp_id,)
101 ) from e
102
103 async def handle_oidc_callback(self, request: SynapseRequest) -> None:
104 """Handle an incoming request to /_synapse/oidc/callback
105
106 Since we might want to display OIDC-related errors in a user-friendly
107 way, we don't raise SynapseError from here. Instead, we call
108 ``self._sso_handler.render_error`` which displays an HTML page for the error.
109
110 Most of the OpenID Connect logic happens here:
111
112 - first, we check if there was any error returned by the provider and
113 display it
114 - then we fetch the session cookie, decode and verify it
115 - the ``state`` query parameter should match with the one stored in the
116 session cookie
117
118 Once we know the session is legit, we then delegate to the OIDC Provider
119 implementation, which will exchange the code with the provider and complete the
120 login/authentication.
121
122 Args:
123 request: the incoming request from the browser.
124 """
125
126 # The provider might redirect with an error.
127 # In that case, just display it as-is.
128 if b"error" in request.args:
129 # error response from the auth server. see:
130 # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
131 # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
132 error = request.args[b"error"][0].decode()
133 description = request.args.get(b"error_description", [b""])[0].decode()
134
135 # Most of the errors returned by the provider could be due to
136 # either the provider misbehaving or Synapse being misconfigured.
137 # The only exception to that is "access_denied", where the user
138 # probably cancelled the login flow. In other cases, log those errors.
139 if error != "access_denied":
140 logger.error("Error from the OIDC provider: %s %s", error, description)
141
142 self._sso_handler.render_error(request, error, description)
143 return
144
145 # otherwise, it is presumably a successful response. see:
146 # https://tools.ietf.org/html/rfc6749#section-4.1.2
147
148 # Fetch the session cookie
149 session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
150 if session is None:
151 logger.info("No session cookie found")
152 self._sso_handler.render_error(
153 request, "missing_session", "No session cookie found"
154 )
155 return
156
157 # Remove the cookie. There is a good chance that if the callback failed
158 # once, it will fail next time and the code will already be exchanged.
159 # Removing it early avoids spamming the provider with token requests.
160 request.addCookie(
161 SESSION_COOKIE_NAME,
162 b"",
163 path="/_synapse/oidc",
164 expires="Thu, Jan 01 1970 00:00:00 UTC",
165 httpOnly=True,
166 sameSite="lax",
167 )
168
169 # Check for the state query parameter
170 if b"state" not in request.args:
171 logger.info("State parameter is missing")
172 self._sso_handler.render_error(
173 request, "invalid_request", "State parameter is missing"
174 )
175 return
176
177 state = request.args[b"state"][0].decode()
178
179 # Deserialize the session token and verify it.
180 try:
181 session_data = self._token_generator.verify_oidc_session_token(
182 session, state
183 )
184 except (MacaroonDeserializationException, ValueError) as e:
185 logger.exception("Invalid session")
186 self._sso_handler.render_error(request, "invalid_session", str(e))
187 return
188 except MacaroonInvalidSignatureException as e:
189 logger.exception("Could not verify session")
190 self._sso_handler.render_error(request, "mismatching_session", str(e))
191 return
192
193 oidc_provider = self._providers.get(session_data.idp_id)
194 if not oidc_provider:
195 logger.error("OIDC session uses unknown IdP %r", session_data.idp_id)
196 self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
197 return
198
199 if b"code" not in request.args:
200 logger.info("Code parameter is missing")
201 self._sso_handler.render_error(
202 request, "invalid_request", "Code parameter is missing"
203 )
204 return
205
206 code = request.args[b"code"][0].decode()
207
208 await oidc_provider.handle_oidc_callback(request, session_data, code)
209
210
73211 class OidcError(Exception):
74212 """Used to catch errors when calling the token_endpoint
75213 """
84222 return self.error
85223
86224
87 class OidcHandler(BaseHandler):
88 """Handles requests related to the OpenID Connect login flow.
225 class OidcProvider:
226 """Wraps the config for a single OIDC IdentityProvider
227
228 Provides methods for handling redirect requests and callbacks via that particular
229 IdP.
89230 """
90231
91 def __init__(self, hs: "HomeServer"):
92 super().__init__(hs)
232 def __init__(
233 self,
234 hs: "HomeServer",
235 token_generator: "OidcSessionTokenGenerator",
236 provider: OidcProviderConfig,
237 ):
238 self._store = hs.get_datastore()
239
240 self._token_generator = token_generator
241
93242 self._callback_url = hs.config.oidc_callback_url # type: str
94 self._scopes = hs.config.oidc_scopes # type: List[str]
95 self._user_profile_method = hs.config.oidc_user_profile_method # type: str
243
244 self._scopes = provider.scopes
245 self._user_profile_method = provider.user_profile_method
96246 self._client_auth = ClientAuth(
97 hs.config.oidc_client_id,
98 hs.config.oidc_client_secret,
99 hs.config.oidc_client_auth_method,
247 provider.client_id, provider.client_secret, provider.client_auth_method,
100248 ) # type: ClientAuth
101 self._client_auth_method = hs.config.oidc_client_auth_method # type: str
249 self._client_auth_method = provider.client_auth_method
102250 self._provider_metadata = OpenIDProviderMetadata(
103 issuer=hs.config.oidc_issuer,
104 authorization_endpoint=hs.config.oidc_authorization_endpoint,
105 token_endpoint=hs.config.oidc_token_endpoint,
106 userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
107 jwks_uri=hs.config.oidc_jwks_uri,
251 issuer=provider.issuer,
252 authorization_endpoint=provider.authorization_endpoint,
253 token_endpoint=provider.token_endpoint,
254 userinfo_endpoint=provider.userinfo_endpoint,
255 jwks_uri=provider.jwks_uri,
108256 ) # type: OpenIDProviderMetadata
109 self._provider_needs_discovery = hs.config.oidc_discover # type: bool
110 self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
111 hs.config.oidc_user_mapping_provider_config
112 ) # type: OidcMappingProvider
113 self._skip_verification = hs.config.oidc_skip_verification # type: bool
114 self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
257 self._provider_needs_discovery = provider.discover
258 self._user_mapping_provider = provider.user_mapping_provider_class(
259 provider.user_mapping_provider_config
260 )
261 self._skip_verification = provider.skip_verification
262 self._allow_existing_users = provider.allow_existing_users
115263
116264 self._http_client = hs.get_proxied_http_client()
117265 self._server_name = hs.config.server_name # type: str
118 self._macaroon_secret_key = hs.config.macaroon_secret_key
119266
120267 # identifier for the external_ids table
121 self._auth_provider_id = "oidc"
268 self.idp_id = provider.idp_id
269
270 # user-facing name of this auth provider
271 self.idp_name = provider.idp_name
272
273 # MXC URI for icon for this auth provider
274 self.idp_icon = provider.idp_icon
122275
123276 self._sso_handler = hs.get_sso_handler()
277
278 self._sso_handler.register_identity_provider(self)
124279
125280 def _validate_metadata(self):
126281 """Verifies the provider metadata.
474629 async def handle_redirect_request(
475630 self,
476631 request: SynapseRequest,
477 client_redirect_url: bytes,
632 client_redirect_url: Optional[bytes],
478633 ui_auth_session_id: Optional[str] = None,
479634 ) -> str:
480635 """Handle an incoming request to /login/sso/redirect
498653 request: the incoming request from the browser.
499654 We'll respond to it with a redirect and a cookie.
500655 client_redirect_url: the URL that we should redirect the client to
501 when everything is done
656 when everything is done (or None for UI Auth)
502657 ui_auth_session_id: The session ID of the ongoing UI Auth (or
503658 None if this is a login).
504659
510665 state = generate_token()
511666 nonce = generate_token()
512667
513 cookie = self._generate_oidc_session_token(
668 if not client_redirect_url:
669 client_redirect_url = b""
670
671 cookie = self._token_generator.generate_oidc_session_token(
514672 state=state,
515 nonce=nonce,
516 client_redirect_url=client_redirect_url.decode(),
517 ui_auth_session_id=ui_auth_session_id,
673 session_data=OidcSessionData(
674 idp_id=self.idp_id,
675 nonce=nonce,
676 client_redirect_url=client_redirect_url.decode(),
677 ui_auth_session_id=ui_auth_session_id,
678 ),
518679 )
519680 request.addCookie(
520681 SESSION_COOKIE_NAME,
537698 nonce=nonce,
538699 )
539700
540 async def handle_oidc_callback(self, request: SynapseRequest) -> None:
701 async def handle_oidc_callback(
702 self, request: SynapseRequest, session_data: "OidcSessionData", code: str
703 ) -> None:
541704 """Handle an incoming request to /_synapse/oidc/callback
542705
543 Since we might want to display OIDC-related errors in a user-friendly
544 way, we don't raise SynapseError from here. Instead, we call
545 ``self._sso_handler.render_error`` which displays an HTML page for the error.
546
547 Most of the OpenID Connect logic happens here:
548
549 - first, we check if there was any error returned by the provider and
550 display it
551 - then we fetch the session cookie, decode and verify it
552 - the ``state`` query parameter should match with the one stored in the
553 session cookie
554 - once we known this session is legit, exchange the code with the
555 provider using the ``token_endpoint`` (see ``_exchange_code``)
706 By this time we have already validated the session on the Synapse side, and
707 now need to do the provider-specific operations. This includes:
708
709 - exchange the code with the provider using the ``token_endpoint`` (see
710 ``_exchange_code``)
556711 - once we have the token, use it to either extract the UserInfo from
557712 the ``id_token`` (``_parse_id_token``), or use the ``access_token``
558713 to fetch UserInfo from the ``userinfo_endpoint``
562717
563718 Args:
564719 request: the incoming request from the browser.
565 """
566
567 # The provider might redirect with an error.
568 # In that case, just display it as-is.
569 if b"error" in request.args:
570 # error response from the auth server. see:
571 # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
572 # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
573 error = request.args[b"error"][0].decode()
574 description = request.args.get(b"error_description", [b""])[0].decode()
575
576 # Most of the errors returned by the provider could be due by
577 # either the provider misbehaving or Synapse being misconfigured.
578 # The only exception of that is "access_denied", where the user
579 # probably cancelled the login flow. In other cases, log those errors.
580 if error != "access_denied":
581 logger.error("Error from the OIDC provider: %s %s", error, description)
582
583 self._sso_handler.render_error(request, error, description)
584 return
585
586 # otherwise, it is presumably a successful response. see:
587 # https://tools.ietf.org/html/rfc6749#section-4.1.2
588
589 # Fetch the session cookie
590 session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
591 if session is None:
592 logger.info("No session cookie found")
593 self._sso_handler.render_error(
594 request, "missing_session", "No session cookie found"
595 )
596 return
597
598 # Remove the cookie. There is a good chance that if the callback failed
599 # once, it will fail next time and the code will already be exchanged.
600 # Removing it early avoids spamming the provider with token requests.
601 request.addCookie(
602 SESSION_COOKIE_NAME,
603 b"",
604 path="/_synapse/oidc",
605 expires="Thu, Jan 01 1970 00:00:00 UTC",
606 httpOnly=True,
607 sameSite="lax",
608 )
609
610 # Check for the state query parameter
611 if b"state" not in request.args:
612 logger.info("State parameter is missing")
613 self._sso_handler.render_error(
614 request, "invalid_request", "State parameter is missing"
615 )
616 return
617
618 state = request.args[b"state"][0].decode()
619
620 # Deserialize the session token and verify it.
720 session_data: the session data, extracted from our cookie
721 code: The authorization code we got from the callback.
722 """
723 # Exchange the code with the provider
621724 try:
622 (
623 nonce,
624 client_redirect_url,
625 ui_auth_session_id,
626 ) = self._verify_oidc_session_token(session, state)
627 except MacaroonDeserializationException as e:
628 logger.exception("Invalid session")
629 self._sso_handler.render_error(request, "invalid_session", str(e))
630 return
631 except MacaroonInvalidSignatureException as e:
632 logger.exception("Could not verify session")
633 self._sso_handler.render_error(request, "mismatching_session", str(e))
634 return
635
636 # Exchange the code with the provider
637 if b"code" not in request.args:
638 logger.info("Code parameter is missing")
639 self._sso_handler.render_error(
640 request, "invalid_request", "Code parameter is missing"
641 )
642 return
643
644 logger.debug("Exchanging code")
645 code = request.args[b"code"][0].decode()
646 try:
725 logger.debug("Exchanging code")
647726 token = await self._exchange_code(code)
648727 except OidcError as e:
649728 logger.exception("Could not exchange code")
665744 else:
666745 logger.debug("Extracting userinfo from id_token")
667746 try:
668 userinfo = await self._parse_id_token(token, nonce=nonce)
747 userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
669748 except Exception as e:
670749 logger.exception("Invalid id_token")
671750 self._sso_handler.render_error(request, "invalid_token", str(e))
672751 return
673752
674753 # first check if we're doing a UIA
675 if ui_auth_session_id:
754 if session_data.ui_auth_session_id:
676755 try:
677756 remote_user_id = self._remote_id_from_userinfo(userinfo)
678757 except Exception as e:
681760 return
682761
683762 return await self._sso_handler.complete_sso_ui_auth_request(
684 self._auth_provider_id, remote_user_id, ui_auth_session_id, request
763 self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
685764 )
686765
687766 # otherwise, it's a login
689768 # Call the mapper to register/login the user
690769 try:
691770 await self._complete_oidc_login(
692 userinfo, token, request, client_redirect_url
771 userinfo, token, request, session_data.client_redirect_url
693772 )
694773 except MappingException as e:
695774 logger.exception("Could not map user")
696775 self._sso_handler.render_error(request, "mapping_error", str(e))
697
698 def _generate_oidc_session_token(
699 self,
700 state: str,
701 nonce: str,
702 client_redirect_url: str,
703 ui_auth_session_id: Optional[str],
704 duration_in_ms: int = (60 * 60 * 1000),
705 ) -> str:
706 """Generates a signed token storing data about an OIDC session.
707
708 When Synapse initiates an authorization flow, it creates a random state
709 and a random nonce. Those parameters are given to the provider and
710 should be verified when the client comes back from the provider.
711 It is also used to store the client_redirect_url, which is used to
712 complete the SSO login flow.
713
714 Args:
715 state: The ``state`` parameter passed to the OIDC provider.
716 nonce: The ``nonce`` parameter passed to the OIDC provider.
717 client_redirect_url: The URL the client gave when it initiated the
718 flow.
719 ui_auth_session_id: The session ID of the ongoing UI Auth (or
720 None if this is a login).
721 duration_in_ms: An optional duration for the token in milliseconds.
722 Defaults to an hour.
723
724 Returns:
725 A signed macaroon token with the session information.
726 """
727 macaroon = pymacaroons.Macaroon(
728 location=self._server_name, identifier="key", key=self._macaroon_secret_key,
729 )
730 macaroon.add_first_party_caveat("gen = 1")
731 macaroon.add_first_party_caveat("type = session")
732 macaroon.add_first_party_caveat("state = %s" % (state,))
733 macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
734 macaroon.add_first_party_caveat(
735 "client_redirect_url = %s" % (client_redirect_url,)
736 )
737 if ui_auth_session_id:
738 macaroon.add_first_party_caveat(
739 "ui_auth_session_id = %s" % (ui_auth_session_id,)
740 )
741 now = self.clock.time_msec()
742 expiry = now + duration_in_ms
743 macaroon.add_first_party_caveat("time < %d" % (expiry,))
744
745 return macaroon.serialize()
746
747 def _verify_oidc_session_token(
748 self, session: bytes, state: str
749 ) -> Tuple[str, str, Optional[str]]:
750 """Verifies and extract an OIDC session token.
751
752 This verifies that a given session token was issued by this homeserver
753 and extract the nonce and client_redirect_url caveats.
754
755 Args:
756 session: The session token to verify
757 state: The state the OIDC provider gave back
758
759 Returns:
760 The nonce, client_redirect_url, and ui_auth_session_id for this session
761 """
762 macaroon = pymacaroons.Macaroon.deserialize(session)
763
764 v = pymacaroons.Verifier()
765 v.satisfy_exact("gen = 1")
766 v.satisfy_exact("type = session")
767 v.satisfy_exact("state = %s" % (state,))
768 v.satisfy_general(lambda c: c.startswith("nonce = "))
769 v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
770 # Sometimes there's a UI auth session ID, it seems to be OK to attempt
771 # to always satisfy this.
772 v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
773 v.satisfy_general(self._verify_expiry)
774
775 v.verify(macaroon, self._macaroon_secret_key)
776
777 # Extract the `nonce`, `client_redirect_url`, and maybe the
778 # `ui_auth_session_id` from the token.
779 nonce = self._get_value_from_macaroon(macaroon, "nonce")
780 client_redirect_url = self._get_value_from_macaroon(
781 macaroon, "client_redirect_url"
782 )
783 try:
784 ui_auth_session_id = self._get_value_from_macaroon(
785 macaroon, "ui_auth_session_id"
786 ) # type: Optional[str]
787 except ValueError:
788 ui_auth_session_id = None
789
790 return nonce, client_redirect_url, ui_auth_session_id
791
792 def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
793 """Extracts a caveat value from a macaroon token.
794
795 Args:
796 macaroon: the token
797 key: the key of the caveat to extract
798
799 Returns:
800 The extracted value
801
802 Raises:
803 Exception: if the caveat was not in the macaroon
804 """
805 prefix = key + " = "
806 for caveat in macaroon.caveats:
807 if caveat.caveat_id.startswith(prefix):
808 return caveat.caveat_id[len(prefix) :]
809 raise ValueError("No %s caveat in macaroon" % (key,))
810
811 def _verify_expiry(self, caveat: str) -> bool:
812 prefix = "time < "
813 if not caveat.startswith(prefix):
814 return False
815 expiry = int(caveat[len(prefix) :])
816 now = self.clock.time_msec()
817 return now < expiry
818776
819777 async def _complete_oidc_login(
820778 self,
892850 # and attempt to match it.
893851 attributes = await oidc_response_to_user_attributes(failures=0)
894852
895 user_id = UserID(attributes.localpart, self.server_name).to_string()
896 users = await self.store.get_users_by_id_case_insensitive(user_id)
853 user_id = UserID(attributes.localpart, self._server_name).to_string()
854 users = await self._store.get_users_by_id_case_insensitive(user_id)
897855 if users:
898856 # If an existing matrix ID is returned, then use it.
899857 if len(users) == 1:
922880 extra_attributes = await get_extra_attributes(userinfo, token)
923881
924882 await self._sso_handler.complete_sso_login_request(
925 self._auth_provider_id,
883 self.idp_id,
926884 remote_user_id,
927885 request,
928886 client_redirect_url,
943901 # Some OIDC providers use integer IDs, but Synapse expects external IDs
944902 # to be strings.
945903 return str(remote_user_id)
904
905
906 class OidcSessionTokenGenerator:
907 """Methods for generating and checking OIDC Session cookies."""
908
909 def __init__(self, hs: "HomeServer"):
910 self._clock = hs.get_clock()
911 self._server_name = hs.hostname
912 self._macaroon_secret_key = hs.config.key.macaroon_secret_key
913
914 def generate_oidc_session_token(
915 self,
916 state: str,
917 session_data: "OidcSessionData",
918 duration_in_ms: int = (60 * 60 * 1000),
919 ) -> str:
920 """Generates a signed token storing data about an OIDC session.
921
922 When Synapse initiates an authorization flow, it creates a random state
923 and a random nonce. Those parameters are given to the provider and
924 should be verified when the client comes back from the provider.
925 It is also used to store the client_redirect_url, which is used to
926 complete the SSO login flow.
927
928 Args:
929 state: The ``state`` parameter passed to the OIDC provider.
930 session_data: data to include in the session token.
931 duration_in_ms: An optional duration for the token in milliseconds.
932 Defaults to an hour.
933
934 Returns:
935 A signed macaroon token with the session information.
936 """
937 macaroon = pymacaroons.Macaroon(
938 location=self._server_name, identifier="key", key=self._macaroon_secret_key,
939 )
940 macaroon.add_first_party_caveat("gen = 1")
941 macaroon.add_first_party_caveat("type = session")
942 macaroon.add_first_party_caveat("state = %s" % (state,))
943 macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
944 macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
945 macaroon.add_first_party_caveat(
946 "client_redirect_url = %s" % (session_data.client_redirect_url,)
947 )
948 if session_data.ui_auth_session_id:
949 macaroon.add_first_party_caveat(
950 "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
951 )
952 now = self._clock.time_msec()
953 expiry = now + duration_in_ms
954 macaroon.add_first_party_caveat("time < %d" % (expiry,))
955
956 return macaroon.serialize()
957
958 def verify_oidc_session_token(
959 self, session: bytes, state: str
960 ) -> "OidcSessionData":
961 """Verifies and extract an OIDC session token.
962
963 This verifies that a given session token was issued by this homeserver
964 and extract the nonce and client_redirect_url caveats.
965
966 Args:
967 session: The session token to verify
968 state: The state the OIDC provider gave back
969
970 Returns:
971 The data extracted from the session cookie
972
973 Raises:
974 ValueError if an expected caveat is missing from the macaroon.
975 """
976 macaroon = pymacaroons.Macaroon.deserialize(session)
977
978 v = pymacaroons.Verifier()
979 v.satisfy_exact("gen = 1")
980 v.satisfy_exact("type = session")
981 v.satisfy_exact("state = %s" % (state,))
982 v.satisfy_general(lambda c: c.startswith("nonce = "))
983 v.satisfy_general(lambda c: c.startswith("idp_id = "))
984 v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
985 # Sometimes there's a UI auth session ID, it seems to be OK to attempt
986 # to always satisfy this.
987 v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
988 v.satisfy_general(self._verify_expiry)
989
990 v.verify(macaroon, self._macaroon_secret_key)
991
992 # Extract the session data from the token.
993 nonce = self._get_value_from_macaroon(macaroon, "nonce")
994 idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
995 client_redirect_url = self._get_value_from_macaroon(
996 macaroon, "client_redirect_url"
997 )
998 try:
999 ui_auth_session_id = self._get_value_from_macaroon(
1000 macaroon, "ui_auth_session_id"
1001 ) # type: Optional[str]
1002 except ValueError:
1003 ui_auth_session_id = None
1004
1005 return OidcSessionData(
1006 nonce=nonce,
1007 idp_id=idp_id,
1008 client_redirect_url=client_redirect_url,
1009 ui_auth_session_id=ui_auth_session_id,
1010 )
1011
1012 def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
1013 """Extracts a caveat value from a macaroon token.
1014
1015 Args:
1016 macaroon: the token
1017 key: the key of the caveat to extract
1018
1019 Returns:
1020 The extracted value
1021
1022 Raises:
1023 ValueError: if the caveat was not in the macaroon
1024 """
1025 prefix = key + " = "
1026 for caveat in macaroon.caveats:
1027 if caveat.caveat_id.startswith(prefix):
1028 return caveat.caveat_id[len(prefix) :]
1029 raise ValueError("No %s caveat in macaroon" % (key,))
1030
1031 def _verify_expiry(self, caveat: str) -> bool:
1032 prefix = "time < "
1033 if not caveat.startswith(prefix):
1034 return False
1035 expiry = int(caveat[len(prefix) :])
1036 now = self._clock.time_msec()
1037 return now < expiry
1038
1039
1040 @attr.s(frozen=True, slots=True)
1041 class OidcSessionData:
1042 """The attributes which are stored in a OIDC session cookie"""
1043
1044 # the Identity Provider being used
1045 idp_id = attr.ib(type=str)
1046
1047 # The `nonce` parameter passed to the OIDC provider.
1048 nonce = attr.ib(type=str)
1049
1050 # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
1051 client_redirect_url = attr.ib(type=str)
1052
1053 # The session ID of the ongoing UI Auth (None if this is a login)
1054 ui_auth_session_id = attr.ib(type=Optional[str], default=None)
9461055
9471056
9481057 UserAttributeDict = TypedDict(
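
The new OidcSessionTokenGenerator stores idp_id, nonce and client_redirect_url (and optionally ui_auth_session_id) as first-party caveats of a macaroon keyed on the homeserver's macaroon secret. A condensed round-trip sketch with pymacaroons follows; the secret key, server name, function names and example values are placeholders, not Synapse's API.

    import pymacaroons

    secret_key = "placeholder-macaroon-secret"

    def make_session(state: str, idp_id: str, nonce: str) -> str:
        # Serialize the session data as first-party caveats, as above.
        m = pymacaroons.Macaroon(location="example.com", identifier="key", key=secret_key)
        m.add_first_party_caveat("gen = 1")
        m.add_first_party_caveat("type = session")
        m.add_first_party_caveat("state = %s" % (state,))
        m.add_first_party_caveat("idp_id = %s" % (idp_id,))
        m.add_first_party_caveat("nonce = %s" % (nonce,))
        return m.serialize()

    def read_idp_id(token: str, state: str) -> str:
        # Verify the signature and the expected caveats, then pull out idp_id.
        m = pymacaroons.Macaroon.deserialize(token)
        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
        v.satisfy_exact("type = session")
        v.satisfy_exact("state = %s" % (state,))
        v.satisfy_general(lambda c: c.startswith(("idp_id = ", "nonce = ")))
        v.verify(m, secret_key)
        return next(c.caveat_id[len("idp_id = "):] for c in m.caveats
                    if c.caveat_id.startswith("idp_id = "))

    token = make_session("st4te", "oidc-github", "n0nce")
    assert read_idp_id(token, "st4te") == "oidc-github"
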
155155 except HttpResponseException as e:
156156 raise e.to_synapse_error()
157157
158 return result["displayname"]
158 return result.get("displayname")
159159
160160 async def set_displayname(
161161 self,
245245 except HttpResponseException as e:
246246 raise e.to_synapse_error()
247247
248 return result["avatar_url"]
248 return result.get("avatar_url")
249249
250250 async def set_avatar_url(
251251 self,
285285 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
286286 )
287287
288 avatar_url_to_set = new_avatar_url # type: Optional[str]
289 if new_avatar_url == "":
290 avatar_url_to_set = None
291
288292 # Same like set_displayname
289293 if by_admin:
290294 requester = create_requester(
291295 target_user, authenticated_entity=requester.authenticated_entity
292296 )
293297
294 await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
298 await self.store.set_profile_avatar_url(
299 target_user.localpart, avatar_url_to_set
300 )
295301
296302 if self.hs.config.user_directory_search_all_users:
297303 profile = await self.store.get_profileinfo(target_user.localpart)
3030 super().__init__(hs)
3131 self.server_name = hs.config.server_name
3232 self.store = hs.get_datastore()
33 self.account_data_handler = hs.get_account_data_handler()
3334 self.read_marker_linearizer = Linearizer(name="read_marker")
34 self.notifier = hs.get_notifier()
3535
3636 async def received_client_read_marker(
3737 self, room_id: str, user_id: str, event_id: str
5858
5959 if should_update:
6060 content = {"event_id": event_id}
61 max_id = await self.store.add_account_data_to_room(
61 await self.account_data_handler.add_account_data_to_room(
6262 user_id, room_id, "m.fully_read", content
6363 )
64 self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
3131 self.server_name = hs.config.server_name
3232 self.store = hs.get_datastore()
3333 self.hs = hs
34 self.federation = hs.get_federation_sender()
35 hs.get_federation_registry().register_edu_handler(
36 "m.receipt", self._received_remote_receipt
37 )
34
35 # We only need to poke the federation sender explicitly if it's on the
36 # same instance. Other federation sender instances will get notified by
37 # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
38 # in the receipts stream.
39 self.federation_sender = None
40 if hs.should_send_federation():
41 self.federation_sender = hs.get_federation_sender()
42
43 # If we can handle the receipt EDUs, we do so; otherwise we route them
44 # to the appropriate worker.
45 if hs.get_instance_name() in hs.config.worker.writers.receipts:
46 hs.get_federation_registry().register_edu_handler(
47 "m.receipt", self._received_remote_receipt
48 )
49 else:
50 hs.get_federation_registry().register_instances_for_edu(
51 "m.receipt", hs.config.worker.writers.receipts,
52 )
53
3854 self.clock = self.hs.get_clock()
3955 self.state = hs.get_state_handler()
4056
124140 if not is_new:
125141 return
126142
127 await self.federation.send_read_receipt(receipt)
143 if self.federation_sender:
144 await self.federation_sender.send_read_receipt(receipt)
128145
129146
130147 class ReceiptEventSource:
3737 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
3838 from synapse.events import EventBase
3939 from synapse.events.utils import copy_power_levels_contents
40 from synapse.http.endpoint import parse_and_validate_server_name
4140 from synapse.storage.state import StateFilter
4241 from synapse.types import (
4342 JsonDict,
5453 from synapse.util import stringutils
5554 from synapse.util.async_helpers import Linearizer
5655 from synapse.util.caches.response_cache import ResponseCache
56 from synapse.util.stringutils import parse_and_validate_server_name
5757 from synapse.visibility import filter_events_for_client
5858
5959 from ._base import BaseHandler
364364 creation_content = {
365365 "room_version": new_room_version.identifier,
366366 "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
367 }
367 } # type: JsonDict
368368
369369 # Check if old room was non-federatable
370370
6262 self.registration_handler = hs.get_registration_handler()
6363 self.profile_handler = hs.get_profile_handler()
6464 self.event_creation_handler = hs.get_event_creation_handler()
65 self.account_data_handler = hs.get_account_data_handler()
6566
6667 self.member_linearizer = Linearizer(name="member")
6768
252253 direct_rooms[key].append(new_room_id)
253254
254255 # Save back to user's m.direct account data
255 await self.store.add_account_data_for_user(
256 await self.account_data_handler.add_account_data_for_user(
256257 user_id, AccountDataTypes.DIRECT, direct_rooms
257258 )
258259 break
262263
263264 # Copy each room tag to the new room
264265 for tag, tag_content in room_tags.items():
265 await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
266 await self.account_data_handler.add_tag_to_room(
267 user_id, new_room_id, tag, tag_content
268 )
266269
267270 async def update_membership(
268271 self,
7272 )
7373
7474 # identifier for the external_ids table
75 self._auth_provider_id = "saml"
75 self.idp_id = "saml"
76
77 # user-facing name of this auth provider
78 self.idp_name = "SAML"
79
80 # we do not currently support icons for SAML auth, but this is required by
81 # the SsoIdentityProvider protocol type.
82 self.idp_icon = None
7683
7784 # a map from saml session id to Saml2SessionData object
7885 self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
7986
8087 self._sso_handler = hs.get_sso_handler()
81
82 def handle_redirect_request(
83 self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
84 ) -> bytes:
88 self._sso_handler.register_identity_provider(self)
89
90 async def handle_redirect_request(
91 self,
92 request: SynapseRequest,
93 client_redirect_url: Optional[bytes],
94 ui_auth_session_id: Optional[str] = None,
95 ) -> str:
8596 """Handle an incoming request to /login/sso/redirect
8697
8798 Args:
99 request: the incoming HTTP request
88100 client_redirect_url: the URL that we should redirect the
89 client to when everything is done
101 client to after login (or None for UI Auth).
90102 ui_auth_session_id: The session ID of the ongoing UI Auth (or
91103 None if this is a login).
92104
93105 Returns:
94106 URL to redirect to
95107 """
108 if not client_redirect_url:
109 # Some SAML identity providers (e.g. Google) require a
110 # RelayState parameter on requests, so pass in a dummy redirect URL
111 # (which will never get used).
112 client_redirect_url = b"unused"
113
96114 reqid, info = self._saml_client.prepare_for_authenticate(
97115 entityid=self._saml_idp_entityid, relay_state=client_redirect_url
98116 )
209227 return
210228
211229 return await self._sso_handler.complete_sso_ui_auth_request(
212 self._auth_provider_id,
230 self.idp_id,
213231 remote_user_id,
214232 current_session.ui_auth_session_id,
215233 request,
305323 return None
306324
307325 await self._sso_handler.complete_sso_login_request(
308 self._auth_provider_id,
326 self.idp_id,
309327 remote_user_id,
310328 request,
311329 client_redirect_url,
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import abc
1415 import logging
15 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
16 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
17 from urllib.parse import urlencode
1618
1719 import attr
18 from typing_extensions import NoReturn
20 from typing_extensions import NoReturn, Protocol
1921
2022 from twisted.web.http import Request
2123
22 from synapse.api.errors import RedirectException, SynapseError
24 from synapse.api.constants import LoginType
25 from synapse.api.errors import Codes, RedirectException, SynapseError
26 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
27 from synapse.http import get_request_user_agent
2328 from synapse.http.server import respond_with_html
2429 from synapse.http.site import SynapseRequest
2530 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
3742
3843 Note that the msg that is raised is shown to end-users.
3944 """
45
46
47 class SsoIdentityProvider(Protocol):
48 """Abstract base class to be implemented by SSO Identity Providers
49
50 An Identity Provider, or IdP, is an external HTTP service which authenticates a user
51 to say whether they should be allowed to log in, or perform a given action.
52
53 Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
54 and CAS.
55
56 The main entry point is `handle_redirect_request`, which should return a URI to
57 redirect the user's browser to the IdP's authentication page.
58
59 Each IdP should be registered with the SsoHandler via
60 `hs.get_sso_handler().register_identity_provider()`, so that requests to
61 `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
62 """
63
64 @property
65 @abc.abstractmethod
66 def idp_id(self) -> str:
67 """A unique identifier for this SSO provider
68
69 Eg, "saml", "cas", "github"
70 """
71
72 @property
73 @abc.abstractmethod
74 def idp_name(self) -> str:
75 """User-facing name for this provider"""
76
77 @property
78 def idp_icon(self) -> Optional[str]:
79 """Optional MXC URI for user-facing icon"""
80 return None
81
82 @abc.abstractmethod
83 async def handle_redirect_request(
84 self,
85 request: SynapseRequest,
86 client_redirect_url: Optional[bytes],
87 ui_auth_session_id: Optional[str] = None,
88 ) -> str:
89 """Handle an incoming request to /login/sso/redirect
90
91 Args:
92 request: the incoming HTTP request
93 client_redirect_url: the URL that we should redirect the
94 client to after login (or None for UI Auth).
95 ui_auth_session_id: The session ID of the ongoing UI Auth (or
96 None if this is a login).
97
98 Returns:
99 URL to redirect to
100 """
101 raise NotImplementedError()
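
To make the protocol above concrete, here is a hedged sketch of a provider that satisfies `SsoIdentityProvider` and registers itself with the SsoHandler. `DummyIdentityProvider` and its target URL are invented for illustration; the real OIDC, SAML and CAS providers build a proper authentication URL and stash state (the client redirect URL and any UI auth session ID) in a cookie or RelayState before redirecting.

```python
from typing import Optional

from synapse.http.site import SynapseRequest


class DummyIdentityProvider:
    idp_id = "dummy"        # unique identifier, as required by the protocol
    idp_name = "Dummy IdP"  # user-facing name
    idp_icon = None         # no icon configured

    def __init__(self, hs):
        # registering lets /_matrix/client/r0/login/sso/redirect dispatch to us
        hs.get_sso_handler().register_identity_provider(self)

    async def handle_redirect_request(
        self,
        request: SynapseRequest,
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        # a real provider would encode the redirect URL / UI auth session here
        return "https://idp.example.com/authorize?client=synapse"
```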
40102
41103
42104 @attr.s
90152 self._store = hs.get_datastore()
91153 self._server_name = hs.hostname
92154 self._registration_handler = hs.get_registration_handler()
155 self._auth_handler = hs.get_auth_handler()
93156 self._error_template = hs.config.sso_error_template
94 self._auth_handler = hs.get_auth_handler()
157 self._bad_user_template = hs.config.sso_auth_bad_user_template
158
159 # The following template is shown after a successful user interactive
160 # authentication session. It tells the user they can close the window.
161 self._sso_auth_success_template = hs.config.sso_auth_success_template
95162
96163 # a lock on the mappings
97164 self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
98165
99166 # a map from session id to session data
100167 self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
168
169 # map from idp_id to SsoIdentityProvider
170 self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
171
172 def register_identity_provider(self, p: SsoIdentityProvider):
173 p_id = p.idp_id
174 assert p_id not in self._identity_providers
175 self._identity_providers[p_id] = p
176
177 def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
178 """Get the configured identity providers"""
179 return self._identity_providers
180
181 async def get_identity_providers_for_user(
182 self, user_id: str
183 ) -> Mapping[str, SsoIdentityProvider]:
184 """Get the SsoIdentityProviders which a user has used
185
186 Given a user id, get the identity providers that that user has used to log in
187 with in the past (and thus could use to re-identify themselves for UI Auth).
188
189 Args:
190 user_id: MXID of user to look up
191
192 Returns:
193 a map of idp_id to SsoIdentityProvider
194 """
195 external_ids = await self._store.get_external_ids_by_user(user_id)
196
197 valid_idps = {}
198 for idp_id, _ in external_ids:
199 idp = self._identity_providers.get(idp_id)
200 if not idp:
201 logger.warning(
202 "User %r has an SSO mapping for IdP %r, but this is no longer "
203 "configured.",
204 user_id,
205 idp_id,
206 )
207 else:
208 valid_idps[idp_id] = idp
209
210 return valid_idps
101211
102212 def render_error(
103213 self,
122232 error=error, error_description=error_description
123233 )
124234 respond_with_html(request, code, html)
235
236 async def handle_redirect_request(
237 self, request: SynapseRequest, client_redirect_url: bytes,
238 ) -> str:
239 """Handle a request to /login/sso/redirect
240
241 Args:
242 request: incoming HTTP request
243 client_redirect_url: the URL that we should redirect the
244 client to after login.
245
246 Returns:
247 the URI to redirect to
248 """
249 if not self._identity_providers:
250 raise SynapseError(
251 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
252 )
253
254 # if we only have one auth provider, redirect to it directly
255 if len(self._identity_providers) == 1:
256 ap = next(iter(self._identity_providers.values()))
257 return await ap.handle_redirect_request(request, client_redirect_url)
258
259 # otherwise, redirect to the IDP picker
260 return "/_synapse/client/pick_idp?" + urlencode(
261 (("redirectUrl", client_redirect_url),)
262 )
125263
126264 async def get_sso_user_by_remote_user_id(
127265 self, auth_provider_id: str, remote_user_id: str
267405 attributes,
268406 auth_provider_id,
269407 remote_user_id,
270 request.get_user_agent(""),
408 get_request_user_agent(request),
271409 request.getClientIP(),
272410 )
273411
450588 auth_provider_id, remote_user_id,
451589 )
452590
591 user_id_to_verify = await self._auth_handler.get_session_data(
592 ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
593 ) # type: str
594
453595 if not user_id:
454596 logger.warning(
455597 "Remote user %s/%s has not previously logged in here: UIA will fail",
456598 auth_provider_id,
457599 remote_user_id,
458600 )
459 # Let the UIA flow handle this the same as if they presented creds for a
460 # different user.
461 user_id = ""
462
463 await self._auth_handler.complete_sso_ui_auth(
464 user_id, ui_auth_session_id, request
465 )
601 elif user_id != user_id_to_verify:
602 logger.warning(
603 "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
604 auth_provider_id,
605 remote_user_id,
606 user_id,
607 )
608 else:
609 # success!
610 # Mark the stage of the authentication as successful.
611 await self._store.mark_ui_auth_stage_complete(
612 ui_auth_session_id, LoginType.SSO, user_id
613 )
614
615 # Render the HTML confirmation page and return.
616 html = self._sso_auth_success_template
617 respond_with_html(request, 200, html)
618 return
619
620 # the user_id didn't match: mark the stage of the authentication as unsuccessful
621 await self._store.mark_ui_auth_stage_complete(
622 ui_auth_session_id, LoginType.SSO, ""
623 )
624
625 # render an error page.
626 html = self._bad_user_template.render(
627 server_name=self._server_name, user_id_to_verify=user_id_to_verify,
628 )
629 respond_with_html(request, 200, html)
466630
467631 async def check_username_availability(
468632 self, localpart: str, session_id: str,
533697 attributes,
534698 session.auth_provider_id,
535699 session.remote_user_id,
536 request.get_user_agent(""),
700 get_request_user_agent(request),
537701 request.getClientIP(),
538702 )
539703
1919 """
2020
2121 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
22
23
24 class UIAuthSessionDataConstants:
25 """Constants for use with AuthHandler.set_session_data"""
26
27 # used during registration and password reset to store a hashed copy of the
28 # password, so that the client does not need to submit it each time.
29 PASSWORD_HASH = "password_hash"
30
31 # used during registration to store the mxid of the registered user
32 REGISTERED_USER_ID = "registered_user_id"
33
34 # used by validate_user_via_ui_auth to store the mxid of the user we are validating
35 # for.
36 REQUEST_USER_ID = "request_user_id"
1616
1717 from twisted.internet import task
1818 from twisted.web.client import FileBodyProducer
19 from twisted.web.iweb import IRequest
1920
2021 from synapse.api.errors import SynapseError
2122
4950 FileBodyProducer.stopProducing(self)
5051 except task.TaskStopped:
5152 pass
53
54
55 def get_request_user_agent(request: IRequest, default: str = "") -> str:
56 """Return the last User-Agent header, or the given default.
57 """
58 # There could be raw utf-8 bytes in the User-Agent header.
59
60 # N.B. if you don't do this, the logger explodes cryptically
61 # with maximum recursion trying to log errors about
62 # the charset problem.
63 # c.f. https://github.com/matrix-org/synapse/issues/3471
64
65 h = request.getHeader(b"User-Agent")
66 return h.decode("ascii", "replace") if h else default
3131
3232 import treq
3333 from canonicaljson import encode_canonical_json
34 from netaddr import IPAddress, IPSet
34 from netaddr import AddrFormatError, IPAddress, IPSet
3535 from prometheus_client import Counter
3636 from zope.interface import implementer, provider
3737
260260
261261 try:
262262 ip_address = IPAddress(h.hostname)
263
263 except AddrFormatError:
264 # Not an IP
265 pass
266 else:
264267 if check_against_blacklist(
265268 ip_address, self._ip_whitelist, self._ip_blacklist
266269 ):
267270 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
268271 e = SynapseError(403, "IP address blocked by IP blacklist entry")
269272 return defer.fail(Failure(e))
270 except Exception:
271 # Not an IP
272 pass
273273
274274 return self._agent.request(
275275 method, uri, headers=headers, bodyProducer=bodyProducer
723723 read_body_with_max_size(response, output_stream, max_size)
724724 )
725725 except BodyExceededMaxSize:
726 SynapseError(
726 raise SynapseError(
727727 502,
728728 "Requested file is too large > %r bytes" % (max_size,),
729729 Codes.TOO_LARGE,
765765 self.max_size = max_size
766766
767767 def dataReceived(self, data: bytes) -> None:
768 # If the deferred was called, bail early.
769 if self.deferred.called:
770 return
771
768772 self.stream.write(data)
769773 self.length += len(data)
774 # The first time the maximum size is exceeded, error and cancel the
775 # connection. dataReceived might be called again if data was received
776 # in the meantime.
770777 if self.max_size is not None and self.length >= self.max_size:
771778 self.deferred.errback(BodyExceededMaxSize())
772 self.deferred = defer.Deferred()
773779 self.transport.loseConnection()
774780
775781 def connectionLost(self, reason: Failure) -> None:
782 # If the maximum size was already exceeded, there's nothing to do.
783 if self.deferred.called:
784 return
785
776786 if reason.check(ResponseDone):
777787 self.deferred.callback(self.length)
778788 elif reason.check(PotentialDataLoss):
+0
-79
synapse/http/endpoint.py
0 # -*- coding: utf-8 -*-
1 # Copyright 2014-2016 OpenMarket Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import logging
15 import re
16
17 logger = logging.getLogger(__name__)
18
19
20 def parse_server_name(server_name):
21 """Split a server name into host/port parts.
22
23 Args:
24 server_name (str): server name to parse
25
26 Returns:
27 Tuple[str, int|None]: host/port parts.
28
29 Raises:
30 ValueError if the server name could not be parsed.
31 """
32 try:
33 if server_name[-1] == "]":
34 # ipv6 literal, hopefully
35 return server_name, None
36
37 domain_port = server_name.rsplit(":", 1)
38 domain = domain_port[0]
39 port = int(domain_port[1]) if domain_port[1:] else None
40 return domain, port
41 except Exception:
42 raise ValueError("Invalid server name '%s'" % server_name)
43
44
45 VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
46
47
48 def parse_and_validate_server_name(server_name):
49 """Split a server name into host/port parts and do some basic validation.
50
51 Args:
52 server_name (str): server name to parse
53
54 Returns:
55 Tuple[str, int|None]: host/port parts.
56
57 Raises:
58 ValueError if the server name could not be parsed.
59 """
60 host, port = parse_server_name(server_name)
61
62 # these tests don't need to be bulletproof as we'll find out soon enough
63 # if somebody is giving us invalid data. What we *do* need is to be sure
64 # that nobody is sneaking IP literals in that look like hostnames, etc.
65
66 # look for ipv6 literals
67 if host[0] == "[":
68 if host[-1] != "]":
69 raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
70 return host, port
71
72 # otherwise it should only be alphanumerics.
73 if not VALID_HOST_REGEX.match(host):
74 raise ValueError(
75 "Server name '%s' contains invalid characters" % (server_name,)
76 )
77
78 return host, port
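
The deleted module above is replaced by an identically-named helper imported from `synapse.util.stringutils` elsewhere in this diff. Assuming the moved implementation is unchanged (as the import swaps suggest), its behaviour looks like this illustrative sketch:

```python
from synapse.util.stringutils import parse_and_validate_server_name

assert parse_and_validate_server_name("matrix.org") == ("matrix.org", None)
assert parse_and_validate_server_name("matrix.org:8448") == ("matrix.org", 8448)
assert parse_and_validate_server_name("[::1]:8448") == ("[::1]", 8448)

try:
    parse_and_validate_server_name("not a valid name!")
except ValueError:
    pass  # hostnames with invalid characters are rejected
```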
101101 pool=self._pool,
102102 contextFactory=tls_client_options_factory,
103103 ),
104 self._reactor,
105104 ip_blacklist=ip_blacklist,
106105 ),
107106 user_agent=self.user_agent,
173173 d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
174174
175175 body = await make_deferred_yieldable(d)
176 except ValueError as e:
177 # The JSON content was invalid.
178 logger.warning(
179 "{%s} [%s] Failed to parse JSON response - %s %s",
180 request.txn_id,
181 request.destination,
182 request.method,
183 request.uri.decode("ascii"),
184 )
185 raise RequestSendFailed(e, can_retry=False) from e
176186 except defer.TimeoutError as e:
177187 logger.warning(
178188 "{%s} [%s] Timed out reading response - %s %s",
985995 logger.warning(
986996 "{%s} [%s] %s", request.txn_id, request.destination, msg,
987997 )
988 SynapseError(502, msg, Codes.TOO_LARGE)
998 raise SynapseError(502, msg, Codes.TOO_LARGE)
989999 except Exception as e:
9901000 logger.warning(
9911001 "{%s} [%s] Error reading response: %s",
1919 from twisted.web.server import Request, Site
2020
2121 from synapse.config.server import ListenerConfig
22 from synapse.http import redact_uri
22 from synapse.http import get_request_user_agent, redact_uri
2323 from synapse.http.request_metrics import RequestMetrics, requests_counter
2424 from synapse.logging.context import LoggingContext, PreserveLoggingContext
2525 from synapse.types import Requester
111111 if isinstance(method, bytes):
112112 method = self.method.decode("ascii")
113113 return method
114
115 def get_user_agent(self, default: str) -> str:
116 """Return the last User-Agent header, or the given default.
117 """
118 user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
119 if user_agent is None:
120 return default
121
122 return user_agent.decode("ascii", "replace")
123114
124115 def render(self, resrc):
125116 # this is called once a Resource has been found to serve the request; in our
291282 # and can see that we're doing something wrong.
292283 authenticated_entity = repr(self.requester) # type: ignore[unreachable]
293284
294 # ...or could be raw utf-8 bytes in the User-Agent header.
295 # N.B. if you don't do this, the logger explodes cryptically
296 # with maximum recursion trying to log errors about
297 # the charset problem.
298 # c.f. https://github.com/matrix-org/synapse/issues/3471
299 user_agent = self.get_user_agent("-")
285 user_agent = get_request_user_agent(self, "-")
300286
301287 code = str(self.code)
302288 if not self.finished:
251251 "scope",
252252 ]
253253
254 def __init__(self, name=None, parent_context=None, request=None) -> None:
254 def __init__(
255 self,
256 name: Optional[str] = None,
257 parent_context: "Optional[LoggingContext]" = None,
258 request: Optional[str] = None,
259 ) -> None:
255260 self.previous_context = current_context()
256261 self.name = name
257262
535540 def __init__(self, request: str = ""):
536541 self._default_request = request
537542
538 def filter(self, record) -> Literal[True]:
543 def filter(self, record: logging.LogRecord) -> Literal[True]:
539544 """Add each field from the logging contexts to the record.
540545 Returns:
541546 True to include the record in the log output.
542547 """
543548 context = current_context()
544 record.request = self._default_request
549 record.request = self._default_request # type: ignore
545550
546551 # context should never be None, but if it somehow ends up being, then
547552 # we end up in a death spiral of infinite loops, so let's check, for
548553 # robustness' sake.
549554 if context is not None:
550555 # Logging is interested in the request.
551 record.request = context.request
556 record.request = context.request # type: ignore
552557
553558 return True
554559
615620 return current
616621
617622
618 def nested_logging_context(
619 suffix: str, parent_context: Optional[LoggingContext] = None
620 ) -> LoggingContext:
623 def nested_logging_context(suffix: str) -> LoggingContext:
621624 """Creates a new logging context as a child of another.
622625
623626 The nested logging context will have a 'request' made up of the parent context's
631634 # ... do stuff
632635
633636 Args:
634 suffix (str): suffix to add to the parent context's 'request'.
635 parent_context (LoggingContext|None): parent context. Will use the current context
636 if None.
637 suffix: suffix to add to the parent context's 'request'.
637638
638639 Returns:
639640 LoggingContext: new logging context.
640641 """
641 if parent_context is not None:
642 context = parent_context # type: LoggingContextOrSentinel
642 curr_context = current_context()
643 if not curr_context:
644 logger.warning(
645 "Starting nested logging context from sentinel context: metrics will be lost"
646 )
647 parent_context = None
648 prefix = ""
643649 else:
644 context = current_context()
645 return LoggingContext(
646 parent_context=context, request=str(context.request) + "-" + suffix
647 )
650 assert isinstance(curr_context, LoggingContext)
651 parent_context = curr_context
652 prefix = str(parent_context.request)
653 return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
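
A hedged usage sketch of the reworked `nested_logging_context`: the child context is now always derived from the current context, so callers no longer pass a parent explicitly. The request name "GET-123" is illustrative.

```python
from synapse.logging.context import LoggingContext, nested_logging_context

with LoggingContext(request="GET-123"):
    with nested_logging_context("persist_events"):
        # log records emitted here carry request="GET-123-persist_events",
        # via the LoggingContextFilter shown above
        pass
```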
648654
649655
650656 def preserve_fn(f):
821827 Deferred: A Deferred which fires a callback with the result of `f`, or an
822828 errback if `f` throws an exception.
823829 """
824 logcontext = current_context()
830 curr_context = current_context()
831 if not curr_context:
832 logger.warning(
833 "Calling defer_to_threadpool from sentinel context: metrics will be lost"
834 )
835 parent_context = None
836 else:
837 assert isinstance(curr_context, LoggingContext)
838 parent_context = curr_context
825839
826840 def g():
827 with LoggingContext(parent_context=logcontext):
841 with LoggingContext(parent_context=parent_context):
828842 return f(*args, **kwargs)
829843
830844 return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
395395
396396 Will wake up all listeners for the given users and rooms.
397397 """
398 with PreserveLoggingContext():
399 with Measure(self.clock, "on_new_event"):
400 user_streams = set()
401
402 for user in users:
403 user_stream = self.user_to_user_stream.get(str(user))
404 if user_stream is not None:
405 user_streams.add(user_stream)
406
407 for room in rooms:
408 user_streams |= self.room_to_user_streams.get(room, set())
409
410 time_now_ms = self.clock.time_msec()
411 for user_stream in user_streams:
412 try:
413 user_stream.notify(stream_key, new_token, time_now_ms)
414 except Exception:
415 logger.exception("Failed to notify listener")
416
417 self.notify_replication()
418
419 # Notify appservices
420 self._notify_app_services_ephemeral(
421 stream_key, new_token, users,
422 )
398 with Measure(self.clock, "on_new_event"):
399 user_streams = set()
400
401 for user in users:
402 user_stream = self.user_to_user_stream.get(str(user))
403 if user_stream is not None:
404 user_streams.add(user_stream)
405
406 for room in rooms:
407 user_streams |= self.room_to_user_streams.get(room, set())
408
409 time_now_ms = self.clock.time_msec()
410 for user_stream in user_streams:
411 try:
412 user_stream.notify(stream_key, new_token, time_now_ms)
413 except Exception:
414 logger.exception("Failed to notify listener")
415
416 self.notify_replication()
417
418 # Notify appservices
419 self._notify_app_services_ephemeral(
420 stream_key, new_token, users,
421 )
423422
424423 def on_new_replication_data(self) -> None:
425424 """Used to inform replication listeners that something has happened
202202
203203 condition_cache = {} # type: Dict[str, bool]
204204
205 # If the event is not a state event, check if any users ignore the sender.
206 if not event.is_state():
207 ignorers = await self.store.ignored_by(event.sender)
208 else:
209 ignorers = set()
210
205211 for uid, rules in rules_by_user.items():
206212 if event.sender == uid:
207213 continue
208214
209 if not event.is_state():
210 is_ignored = await self.store.is_ignored_by(event.sender, uid)
211 if is_ignored:
212 continue
215 if uid in ignorers:
216 continue
213217
214218 display_name = None
215219 profile_info = room_members.get(uid)
8585
8686 CONDITIONAL_REQUIREMENTS = {
8787 "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
88 # we use execute_batch, which arrived in psycopg 2.7.
89 "postgres": ["psycopg2>=2.7"],
88 # we use execute_values with the fetch param, which arrived in psycopg 2.8.
89 "postgres": ["psycopg2>=2.8"],
9090 # ACME support is required to provision TLS certificates from authorities
9191 # that use the protocol, such as Let's Encrypt.
9292 "acme": [
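
The psycopg2 bump above is motivated by `execute_values(..., fetch=True)`, which arrived in psycopg2 2.8. A minimal sketch of that call follows; the connection string and table are placeholders, not part of Synapse's schema.

```python
import psycopg2
from psycopg2.extras import execute_values

conn = psycopg2.connect("dbname=synapse user=synapse")
with conn, conn.cursor() as cur:
    rows = [("@alice:example.com", 1), ("@bob:example.com", 2)]
    # with fetch=True, the rows produced by the RETURNING clause are returned
    returned = execute_values(
        cur,
        "INSERT INTO demo_table (user_id, stream_id) VALUES %s RETURNING stream_id",
        rows,
        fetch=True,
    )
    print(returned)
```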
1414
1515 from synapse.http.server import JsonResource
1616 from synapse.replication.http import (
17 account_data,
1718 devices,
1819 federation,
1920 login,
3940 presence.register_servlets(hs, self)
4041 membership.register_servlets(hs, self)
4142 streams.register_servlets(hs, self)
43 account_data.register_servlets(hs, self)
4244
4345 # The following can't currently be instantiated on workers.
4446 if hs.config.worker.worker_app is None:
176176
177177 @trace(opname="outgoing_replication_request")
178178 @outgoing_gauge.track_inprogress()
179 async def send_request(instance_name="master", **kwargs):
179 async def send_request(*, instance_name="master", **kwargs):
180180 if instance_name == local_instance_name:
181181 raise Exception("Trying to send HTTP request to self")
182182 if instance_name == "master":
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import logging
16
17 from synapse.http.servlet import parse_json_object_from_request
18 from synapse.replication.http._base import ReplicationEndpoint
19
20 logger = logging.getLogger(__name__)
21
22
23 class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
24 """Add user account data on the appropriate account data worker.
25
26 Request format:
27
28 POST /_synapse/replication/add_user_account_data/:user_id/:type
29
30 {
31 "content": { ... },
32 }
33
34 """
35
36 NAME = "add_user_account_data"
37 PATH_ARGS = ("user_id", "account_data_type")
38 CACHE = False
39
40 def __init__(self, hs):
41 super().__init__(hs)
42
43 self.handler = hs.get_account_data_handler()
44 self.clock = hs.get_clock()
45
46 @staticmethod
47 async def _serialize_payload(user_id, account_data_type, content):
48 payload = {
49 "content": content,
50 }
51
52 return payload
53
54 async def _handle_request(self, request, user_id, account_data_type):
55 content = parse_json_object_from_request(request)
56
57 max_stream_id = await self.handler.add_account_data_for_user(
58 user_id, account_data_type, content["content"]
59 )
60
61 return 200, {"max_stream_id": max_stream_id}
62
63
64 class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
65 """Add room account data on the appropriate account data worker.
66
67 Request format:
68
69 POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
70
71 {
72 "content": { ... },
73 }
74
75 """
76
77 NAME = "add_room_account_data"
78 PATH_ARGS = ("user_id", "room_id", "account_data_type")
79 CACHE = False
80
81 def __init__(self, hs):
82 super().__init__(hs)
83
84 self.handler = hs.get_account_data_handler()
85 self.clock = hs.get_clock()
86
87 @staticmethod
88 async def _serialize_payload(user_id, room_id, account_data_type, content):
89 payload = {
90 "content": content,
91 }
92
93 return payload
94
95 async def _handle_request(self, request, user_id, room_id, account_data_type):
96 content = parse_json_object_from_request(request)
97
98 max_stream_id = await self.handler.add_account_data_to_room(
99 user_id, room_id, account_data_type, content["content"]
100 )
101
102 return 200, {"max_stream_id": max_stream_id}
103
104
105 class ReplicationAddTagRestServlet(ReplicationEndpoint):
106 """Add tag on the appropriate account data worker.
107
108 Request format:
109
110 POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
111
112 {
113 "content": { ... },
114 }
115
116 """
117
118 NAME = "add_tag"
119 PATH_ARGS = ("user_id", "room_id", "tag")
120 CACHE = False
121
122 def __init__(self, hs):
123 super().__init__(hs)
124
125 self.handler = hs.get_account_data_handler()
126 self.clock = hs.get_clock()
127
128 @staticmethod
129 async def _serialize_payload(user_id, room_id, tag, content):
130 payload = {
131 "content": content,
132 }
133
134 return payload
135
136 async def _handle_request(self, request, user_id, room_id, tag):
137 content = parse_json_object_from_request(request)
138
139 max_stream_id = await self.handler.add_tag_to_room(
140 user_id, room_id, tag, content["content"]
141 )
142
143 return 200, {"max_stream_id": max_stream_id}
144
145
146 class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
147 """Remove tag on the appropriate account data worker.
148
149 Request format:
150
151 POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
152
153 {}
154
155 """
156
157 NAME = "remove_tag"
158 PATH_ARGS = (
159 "user_id",
160 "room_id",
161 "tag",
162 )
163 CACHE = False
164
165 def __init__(self, hs):
166 super().__init__(hs)
167
168 self.handler = hs.get_account_data_handler()
169 self.clock = hs.get_clock()
170
171 @staticmethod
172 async def _serialize_payload(user_id, room_id, tag):
173
174 return {}
175
176 async def _handle_request(self, request, user_id, room_id, tag):
177 max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
178
179 return 200, {"max_stream_id": max_stream_id}
180
181
182 def register_servlets(hs, http_server):
183 ReplicationUserAccountDataRestServlet(hs).register(http_server)
184 ReplicationRoomAccountDataRestServlet(hs).register(http_server)
185 ReplicationAddTagRestServlet(hs).register(http_server)
186 ReplicationRemoveTagRestServlet(hs).register(http_server)
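
For context, a hedged sketch of how a worker-side handler could call the user account data endpoint defined above via `ReplicationEndpoint.make_client`. The handler class and the choice of writer instance are illustrative, not the actual AccountDataHandler implementation.

```python
from synapse.replication.http.account_data import (
    ReplicationUserAccountDataRestServlet,
)


class SketchAccountDataHandler:
    def __init__(self, hs):
        # make_client returns an async callable which issues
        # POST /_synapse/replication/add_user_account_data/:user_id/:type
        self._add_user_account_data = (
            ReplicationUserAccountDataRestServlet.make_client(hs)
        )
        # assume a single configured account_data writer, for illustration
        self._writer = hs.config.worker.writers.account_data[0]

    async def add_account_data_for_user(self, user_id, account_data_type, content):
        response = await self._add_user_account_data(
            instance_name=self._writer,
            user_id=user_id,
            account_data_type=account_data_type,
            content=content,
        )
        return response["max_stream_id"]
```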
3232 database,
3333 stream_name="caches",
3434 instance_name=hs.get_instance_name(),
35 table="cache_invalidation_stream_by_instance",
36 instance_column="instance_name",
37 id_column="stream_id",
35 tables=[
36 (
37 "cache_invalidation_stream_by_instance",
38 "instance_name",
39 "stream_id",
40 )
41 ],
3842 sequence_name="cache_invalidation_stream_seq",
3943 writers=[],
4044 ) # type: Optional[MultiWriterIdGenerator]
1414 # limitations under the License.
1515
1616 from synapse.replication.slave.storage._base import BaseSlavedStore
17 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
18 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
19 from synapse.storage.database import DatabasePool
2017 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
2118 from synapse.storage.databases.main.tags import TagsWorkerStore
2219
2320
2421 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
25 def __init__(self, database: DatabasePool, db_conn, hs):
26 self._account_data_id_gen = SlavedIdTracker(
27 db_conn,
28 "account_data",
29 "stream_id",
30 extra_tables=[
31 ("room_account_data", "stream_id"),
32 ("room_tags_revisions", "stream_id"),
33 ],
34 )
35
36 super().__init__(database, db_conn, hs)
37
38 def get_max_account_data_stream_id(self):
39 return self._account_data_id_gen.get_current_token()
40
41 def process_replication_rows(self, stream_name, instance_name, token, rows):
42 if stream_name == TagAccountDataStream.NAME:
43 self._account_data_id_gen.advance(instance_name, token)
44 for row in rows:
45 self.get_tags_for_user.invalidate((row.user_id,))
46 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
47 elif stream_name == AccountDataStream.NAME:
48 self._account_data_id_gen.advance(instance_name, token)
49 for row in rows:
50 if not row.room_id:
51 self.get_global_account_data_by_type_for_user.invalidate(
52 (row.data_type, row.user_id)
53 )
54 self.get_account_data_for_user.invalidate((row.user_id,))
55 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
56 self.get_account_data_for_room_and_type.invalidate(
57 (row.user_id, row.room_id, row.data_type)
58 )
59 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
60 return super().process_replication_rows(stream_name, instance_name, token, rows)
22 pass
1313 # limitations under the License.
1414
1515 from synapse.replication.slave.storage._base import BaseSlavedStore
16 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
17 from synapse.replication.tcp.streams import ToDeviceStream
18 from synapse.storage.database import DatabasePool
1916 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
20 from synapse.util.caches.expiringcache import ExpiringCache
21 from synapse.util.caches.stream_change_cache import StreamChangeCache
2217
2318
2419 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
25 def __init__(self, database: DatabasePool, db_conn, hs):
26 super().__init__(database, db_conn, hs)
27 self._device_inbox_id_gen = SlavedIdTracker(
28 db_conn, "device_inbox", "stream_id"
29 )
30 self._device_inbox_stream_cache = StreamChangeCache(
31 "DeviceInboxStreamChangeCache",
32 self._device_inbox_id_gen.get_current_token(),
33 )
34 self._device_federation_outbox_stream_cache = StreamChangeCache(
35 "DeviceFederationOutboxStreamChangeCache",
36 self._device_inbox_id_gen.get_current_token(),
37 )
38
39 self._last_device_delete_cache = ExpiringCache(
40 cache_name="last_device_delete_cache",
41 clock=self._clock,
42 max_len=10000,
43 expiry_ms=30 * 60 * 1000,
44 )
45
46 def process_replication_rows(self, stream_name, instance_name, token, rows):
47 if stream_name == ToDeviceStream.NAME:
48 self._device_inbox_id_gen.advance(instance_name, token)
49 for row in rows:
50 if row.entity.startswith("@"):
51 self._device_inbox_stream_cache.entity_has_changed(
52 row.entity, token
53 )
54 else:
55 self._device_federation_outbox_stream_cache.entity_has_changed(
56 row.entity, token
57 )
58 return super().process_replication_rows(stream_name, instance_name, token, rows)
20 pass
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 from synapse.replication.tcp.streams import ReceiptsStream
17 from synapse.storage.database import DatabasePool
1816 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
1917
2018 from ._base import BaseSlavedStore
21 from ._slaved_id_tracker import SlavedIdTracker
2219
2320
2421 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
25 def __init__(self, database: DatabasePool, db_conn, hs):
26 # We instantiate this first as the ReceiptsWorkerStore constructor
27 # needs to be able to call get_max_receipt_stream_id
28 self._receipts_id_gen = SlavedIdTracker(
29 db_conn, "receipts_linearized", "stream_id"
30 )
31
32 super().__init__(database, db_conn, hs)
33
34 def get_max_receipt_stream_id(self):
35 return self._receipts_id_gen.get_current_token()
36
37 def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
38 self.get_receipts_for_user.invalidate((user_id, receipt_type))
39 self._get_linearized_receipts_for_room.invalidate_many((room_id,))
40 self.get_last_receipt_event_id_for_user.invalidate(
41 (user_id, room_id, receipt_type)
42 )
43 self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
44 self.get_receipts_for_room.invalidate((room_id, receipt_type))
45
46 def process_replication_rows(self, stream_name, instance_name, token, rows):
47 if stream_name == ReceiptsStream.NAME:
48 self._receipts_id_gen.advance(instance_name, token)
49 for row in rows:
50 self.invalidate_caches_for_receipt(
51 row.room_id, row.receipt_type, row.user_id
52 )
53 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
54
55 return super().process_replication_rows(stream_name, instance_name, token, rows)
22 pass
5050 from synapse.replication.tcp.protocol import AbstractConnection
5151 from synapse.replication.tcp.streams import (
5252 STREAMS_MAP,
53 AccountDataStream,
5354 BackfillStream,
5455 CachesStream,
5556 EventsStream,
5657 FederationStream,
58 ReceiptsStream,
5759 Stream,
60 TagAccountDataStream,
61 ToDeviceStream,
5862 TypingStream,
5963 )
6064
114118
115119 continue
116120
121 if isinstance(stream, ToDeviceStream):
122 # Only add ToDeviceStream as a source on instances in charge of
123 # sending to device messages.
124 if hs.get_instance_name() in hs.config.worker.writers.to_device:
125 self._streams_to_replicate.append(stream)
126
127 continue
128
117129 if isinstance(stream, TypingStream):
118130 # Only add TypingStream as a source on the instance in charge of
119131 # typing.
120132 if hs.config.worker.writers.typing == hs.get_instance_name():
133 self._streams_to_replicate.append(stream)
134
135 continue
136
137 if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
138 # Only add AccountDataStream and TagAccountDataStream as a source on the
139 # instance in charge of account_data persistence.
140 if hs.get_instance_name() in hs.config.worker.writers.account_data:
141 self._streams_to_replicate.append(stream)
142
143 continue
144
145 if isinstance(stream, ReceiptsStream):
146 # Only add ReceiptsStream as a source on the instance in charge of
147 # receipts.
148 if hs.get_instance_name() in hs.config.worker.writers.receipts:
121149 self._streams_to_replicate.append(stream)
122150
123151 continue
0 <html>
1 <head>
2 <title>Authentication Failed</title>
3 </head>
4 <body>
5 <div>
6 <p>
7 We were unable to validate your <tt>{{server_name | e}}</tt> account via
8 single-sign-on (SSO), because the SSO Identity Provider returned
9 different details than when you logged in.
10 </p>
11 <p>
12 Try the operation again, and ensure that you use the same details on
13 the Identity Provider as when you log into your account.
14 </p>
15 </div>
16 </body>
17 </html>
0 <!DOCTYPE html>
1 <html lang="en">
2 <head>
3 <meta charset="UTF-8">
4 <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
5 <title>{{server_name | e}} Login</title>
6 </head>
7 <body>
8 <div id="container">
9 <h1 id="title">{{server_name | e}} Login</h1>
10 <div class="login_flow">
11 <p>Choose one of the following identity providers:</p>
12 <form>
13 <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
14 <ul class="radiobuttons">
15 {% for p in providers %}
16 <li>
17 <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
18 <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
19 {% if p.idp_icon %}
20 <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
21 {% endif %}
22 </li>
23 {% endfor %}
24 </ul>
25 <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
26 </form>
27 </div>
28 </div>
29 </body>
30 </html>
1414 # limitations under the License.
1515
1616 import logging
17 from typing import TYPE_CHECKING, Tuple
18
19 from twisted.web.http import Request
1720
1821 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
1922 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
2225 assert_requester_is_admin,
2326 assert_user_is_admin,
2427 )
28 from synapse.types import JsonDict
29
30 if TYPE_CHECKING:
31 from synapse.app.homeserver import HomeServer
2532
2633 logger = logging.getLogger(__name__)
2734
3845 admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
3946 )
4047
41 def __init__(self, hs):
42 self.store = hs.get_datastore()
43 self.auth = hs.get_auth()
44
45 async def on_POST(self, request, room_id: str):
48 def __init__(self, hs: "HomeServer"):
49 self.store = hs.get_datastore()
50 self.auth = hs.get_auth()
51
52 async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
4653 requester = await self.auth.get_user_by_req(request)
4754 await assert_user_is_admin(self.auth, requester.user)
4855
6370
6471 PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
6572
66 def __init__(self, hs):
67 self.store = hs.get_datastore()
68 self.auth = hs.get_auth()
69
70 async def on_POST(self, request, user_id: str):
73 def __init__(self, hs: "HomeServer"):
74 self.store = hs.get_datastore()
75 self.auth = hs.get_auth()
76
77 async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
7178 requester = await self.auth.get_user_by_req(request)
7279 await assert_user_is_admin(self.auth, requester.user)
7380
9097 "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
9198 )
9299
93 def __init__(self, hs):
94 self.store = hs.get_datastore()
95 self.auth = hs.get_auth()
96
97 async def on_POST(self, request, server_name: str, media_id: str):
100 def __init__(self, hs: "HomeServer"):
101 self.store = hs.get_datastore()
102 self.auth = hs.get_auth()
103
104 async def on_POST(
105 self, request: Request, server_name: str, media_id: str
106 ) -> Tuple[int, JsonDict]:
98107 requester = await self.auth.get_user_by_req(request)
99108 await assert_user_is_admin(self.auth, requester.user)
100109
108117 return 200, {}
109118
110119
120 class ProtectMediaByID(RestServlet):
121 """Protect local media from being quarantined.
122 """
123
124 PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
125
126 def __init__(self, hs: "HomeServer"):
127 self.store = hs.get_datastore()
128 self.auth = hs.get_auth()
129
130 async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
131 requester = await self.auth.get_user_by_req(request)
132 await assert_user_is_admin(self.auth, requester.user)
133
134 logger.info("Protecting local media by ID: %s", media_id)
135
136 # Mark this media ID as safe, protecting it from quarantine
137 await self.store.mark_local_media_as_safe(media_id)
138
139 return 200, {}
140
141
111142 class ListMediaInRoom(RestServlet):
112143 """Lists all of the media in a given room.
113144 """
114145
115146 PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
116147
117 def __init__(self, hs):
118 self.store = hs.get_datastore()
119 self.auth = hs.get_auth()
120
121 async def on_GET(self, request, room_id):
148 def __init__(self, hs: "HomeServer"):
149 self.store = hs.get_datastore()
150 self.auth = hs.get_auth()
151
152 async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
122153 requester = await self.auth.get_user_by_req(request)
123154 is_admin = await self.auth.is_server_admin(requester.user)
124155 if not is_admin:
132163 class PurgeMediaCacheRestServlet(RestServlet):
133164 PATTERNS = admin_patterns("/purge_media_cache")
134165
135 def __init__(self, hs):
166 def __init__(self, hs: "HomeServer"):
136167 self.media_repository = hs.get_media_repository()
137168 self.auth = hs.get_auth()
138169
139 async def on_POST(self, request):
170 async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
140171 await assert_requester_is_admin(self.auth, request)
141172
142173 before_ts = parse_integer(request, "before_ts", required=True)
153184
154185 PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
155186
156 def __init__(self, hs):
187 def __init__(self, hs: "HomeServer"):
157188 self.store = hs.get_datastore()
158189 self.auth = hs.get_auth()
159190 self.server_name = hs.hostname
160191 self.media_repository = hs.get_media_repository()
161192
162 async def on_DELETE(self, request, server_name: str, media_id: str):
193 async def on_DELETE(
194 self, request: Request, server_name: str, media_id: str
195 ) -> Tuple[int, JsonDict]:
163196 await assert_requester_is_admin(self.auth, request)
164197
165198 if self.server_name != server_name:
181214
182215 PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
183216
184 def __init__(self, hs):
217 def __init__(self, hs: "HomeServer"):
185218 self.store = hs.get_datastore()
186219 self.auth = hs.get_auth()
187220 self.server_name = hs.hostname
188221 self.media_repository = hs.get_media_repository()
189222
190 async def on_POST(self, request, server_name: str):
223 async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
191224 await assert_requester_is_admin(self.auth, request)
192225
193226 before_ts = parse_integer(request, "before_ts", required=True)
221254 return 200, {"deleted_media": deleted_media, "total": total}
222255
223256
224 def register_servlets_for_media_repo(hs, http_server):
257 def register_servlets_for_media_repo(hs: "HomeServer", http_server):
225258 """
226259 Media repo specific APIs.
227260 """
229262 QuarantineMediaInRoom(hs).register(http_server)
230263 QuarantineMediaByID(hs).register(http_server)
231264 QuarantineMediaByUser(hs).register(http_server)
265 ProtectMediaByID(hs).register(http_server)
232266 ListMediaInRoom(hs).register(http_server)
233267 DeleteMediaByID(hs).register(http_server)
234268 DeleteMediaByDateSize(hs).register(http_server)
243243
244244 if deactivate and not user["deactivated"]:
245245 await self.deactivate_account_handler.deactivate_account(
246 target_user.to_string(), False
246 target_user.to_string(), False, requester, by_admin=True
247247 )
248248 elif not deactivate and user["deactivated"]:
249249 if "password" not in body:
485485 class DeactivateAccountRestServlet(RestServlet):
486486 PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
487487
488 def __init__(self, hs):
488 def __init__(self, hs: "HomeServer"):
489489 self._deactivate_account_handler = hs.get_deactivate_account_handler()
490490 self.auth = hs.get_auth()
491
492 async def on_POST(self, request, target_user_id):
493 await assert_requester_is_admin(self.auth, request)
491 self.is_mine = hs.is_mine
492 self.store = hs.get_datastore()
493
494 async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
495 requester = await self.auth.get_user_by_req(request)
496 await assert_user_is_admin(self.auth, requester.user)
497
498 if not self.is_mine(UserID.from_string(target_user_id)):
499 raise SynapseError(400, "Can only deactivate local users")
500
501 if not await self.store.get_user_by_id(target_user_id):
502 raise NotFoundError("User not found")
503
494504 body = parse_json_object_from_request(request, allow_empty_body=True)
495505 erase = body.get("erase", False)
496506 if not isinstance(erase, bool):
500510 Codes.BAD_JSON,
501511 )
502512
503 UserID.from_string(target_user_id)
504
505513 result = await self._deactivate_account_handler.deactivate_account(
506 target_user_id, erase
514 target_user_id, erase, requester, by_admin=True
507515 )
508516 if result:
509517 id_server_unbind_result = "success"
713721 async def on_GET(self, request, user_id):
714722 await assert_requester_is_admin(self.auth, request)
715723
716 if not self.is_mine(UserID.from_string(user_id)):
717 raise SynapseError(400, "Can only lookup local users")
718
719 user = await self.store.get_user_by_id(user_id)
720 if user is None:
721 raise NotFoundError("Unknown user")
722
723724 room_ids = await self.store.get_rooms_for_user(user_id)
724725 ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
725726 return 200, ret
310310 return result
311311
312312
313 class BaseSSORedirectServlet(RestServlet):
314 """Common base class for /login/sso/redirect impls"""
315
313 class SsoRedirectServlet(RestServlet):
316314 PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
317315
316 def __init__(self, hs: "HomeServer"):
317 # make sure that the relevant handlers are instantiated, so that they
318 # register themselves with the main SSOHandler.
319 if hs.config.cas_enabled:
320 hs.get_cas_handler()
321 if hs.config.saml2_enabled:
322 hs.get_saml_handler()
323 if hs.config.oidc_enabled:
324 hs.get_oidc_handler()
325 self._sso_handler = hs.get_sso_handler()
326
318327 async def on_GET(self, request: SynapseRequest):
319 args = request.args
320 if b"redirectUrl" not in args:
321 return 400, "Redirect URL not specified for SSO auth"
322 client_redirect_url = args[b"redirectUrl"][0]
323 sso_url = await self.get_sso_url(request, client_redirect_url)
328 client_redirect_url = parse_string(
329 request, "redirectUrl", required=True, encoding=None
330 )
331 sso_url = await self._sso_handler.handle_redirect_request(
332 request, client_redirect_url
333 )
334 logger.info("Redirecting to %s", sso_url)
324335 request.redirect(sso_url)
325336 finish_request(request)
326
327 async def get_sso_url(
328 self, request: SynapseRequest, client_redirect_url: bytes
329 ) -> bytes:
330 """Get the URL to redirect to, to perform SSO auth
331
332 Args:
333 request: The client request to redirect.
334 client_redirect_url: the URL that we should redirect the
335 client to when everything is done
336
337 Returns:
338 URL to redirect to
339 """
340 # to be implemented by subclasses
341 raise NotImplementedError()
342
343
344 class CasRedirectServlet(BaseSSORedirectServlet):
345 def __init__(self, hs):
346 self._cas_handler = hs.get_cas_handler()
347
348 async def get_sso_url(
349 self, request: SynapseRequest, client_redirect_url: bytes
350 ) -> bytes:
351 return self._cas_handler.get_redirect_url(
352 {"redirectUrl": client_redirect_url}
353 ).encode("ascii")
354337
355338
356339 class CasTicketServlet(RestServlet):
378361 )
379362
380363
381 class SAMLRedirectServlet(BaseSSORedirectServlet):
382 PATTERNS = client_patterns("/login/sso/redirect", v1=True)
383
384 def __init__(self, hs):
385 self._saml_handler = hs.get_saml_handler()
386
387 async def get_sso_url(
388 self, request: SynapseRequest, client_redirect_url: bytes
389 ) -> bytes:
390 return self._saml_handler.handle_redirect_request(client_redirect_url)
391
392
393 class OIDCRedirectServlet(BaseSSORedirectServlet):
394 """Implementation for /login/sso/redirect for the OIDC login flow."""
395
396 PATTERNS = client_patterns("/login/sso/redirect", v1=True)
397
398 def __init__(self, hs):
399 self._oidc_handler = hs.get_oidc_handler()
400
401 async def get_sso_url(
402 self, request: SynapseRequest, client_redirect_url: bytes
403 ) -> bytes:
404 return await self._oidc_handler.handle_redirect_request(
405 request, client_redirect_url
406 )
407
408
409364 def register_servlets(hs, http_server):
410365 LoginRestServlet(hs).register(http_server)
366 SsoRedirectServlet(hs).register(http_server)
411367 if hs.config.cas_enabled:
412 CasRedirectServlet(hs).register(http_server)
413368 CasTicketServlet(hs).register(http_server)
414 elif hs.config.saml2_enabled:
415 SAMLRedirectServlet(hs).register(http_server)
416 elif hs.config.oidc_enabled:
417 OIDCRedirectServlet(hs).register(http_server)
4545 from synapse.streams.config import PaginationConfig
4646 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
4747 from synapse.util import json_decoder
48 from synapse.util.stringutils import random_string
48 from synapse.util.stringutils import parse_and_validate_server_name, random_string
4949
5050 if TYPE_CHECKING:
5151 import synapse.server
346346 # provided.
347347 if server:
348348 raise e
349 else:
350 pass
351349
352350 limit = parse_integer(request, "limit", 0)
353351 since_token = parse_string(request, "since", None)
358356
359357 handler = self.hs.get_room_list_handler()
360358 if server and server != self.hs.config.server_name:
359 # Ensure the server is valid.
360 try:
361 parse_and_validate_server_name(server)
362 except ValueError:
363 raise SynapseError(
364 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
365 )
366
361367 try:
362368 data = await handler.get_remote_public_room_list(
363369 server, limit=limit, since_token=since_token
401407
402408 handler = self.hs.get_room_list_handler()
403409 if server and server != self.hs.config.server_name:
410 # Ensure the server is valid.
411 try:
412 parse_and_validate_server_name(server)
413 except ValueError:
414 raise SynapseError(
415 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
416 )
417
404418 try:
405419 data = await handler.get_remote_public_room_list(
406420 server,
1919 from typing import TYPE_CHECKING
2020 from urllib.parse import urlparse
2121
22 if TYPE_CHECKING:
23 from synapse.app.homeserver import HomeServer
24
2522 from synapse.api.constants import LoginType
2623 from synapse.api.errors import (
2724 Codes,
3027 ThreepidValidationError,
3128 )
3229 from synapse.config.emailconfig import ThreepidBehaviour
30 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
3331 from synapse.http.server import finish_request, respond_with_html
3432 from synapse.http.servlet import (
3533 RestServlet,
4543
4644 from ._base import client_patterns, interactive_auth_handler
4745
46 if TYPE_CHECKING:
47 from synapse.app.homeserver import HomeServer
48
49
4850 logger = logging.getLogger(__name__)
4951
5052
188190 requester = await self.auth.get_user_by_req(request)
189191 try:
190192 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
191 requester,
193 requester, request, body, "modify your account password",
194 )
195 except InteractiveAuthIncompleteError as e:
196 # The user needs to provide more steps to complete auth, but
197 # they're not required to provide the password again.
198 #
199 # If a password is available now, hash the provided password and
200 # store it for later.
201 if new_password:
202 password_hash = await self.auth_handler.hash(new_password)
203 await self.auth_handler.set_session_data(
204 e.session_id,
205 UIAuthSessionDataConstants.PASSWORD_HASH,
206 password_hash,
207 )
208 raise
209 user_id = requester.user.to_string()
210 else:
211 requester = None
212 try:
213 result, params, session_id = await self.auth_handler.check_ui_auth(
214 [[LoginType.EMAIL_IDENTITY]],
192215 request,
193216 body,
194 self.hs.get_ip_from_request(request),
195217 "modify your account password",
196218 )
197219 except InteractiveAuthIncompleteError as e:
203225 if new_password:
204226 password_hash = await self.auth_handler.hash(new_password)
205227 await self.auth_handler.set_session_data(
206 e.session_id, "password_hash", password_hash
207 )
208 raise
209 user_id = requester.user.to_string()
210 else:
211 requester = None
212 try:
213 result, params, session_id = await self.auth_handler.check_ui_auth(
214 [[LoginType.EMAIL_IDENTITY]],
215 request,
216 body,
217 self.hs.get_ip_from_request(request),
218 "modify your account password",
219 )
220 except InteractiveAuthIncompleteError as e:
221 # The user needs to provide more steps to complete auth, but
222 # they're not required to provide the password again.
223 #
224 # If a password is available now, hash the provided password and
225 # store it for later.
226 if new_password:
227 password_hash = await self.auth_handler.hash(new_password)
228 await self.auth_handler.set_session_data(
229 e.session_id, "password_hash", password_hash
228 e.session_id,
229 UIAuthSessionDataConstants.PASSWORD_HASH,
230 password_hash,
230231 )
231232 raise
232233
259260 password_hash = await self.auth_handler.hash(new_password)
260261 elif session_id is not None:
261262 password_hash = await self.auth_handler.get_session_data(
262 session_id, "password_hash", None
263 session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
263264 )
264265 else:
265266 # UI validation was skipped, but the request did not include a new
303304 # allow ASes to deactivate their own users
304305 if requester.app_service:
305306 await self._deactivate_account_handler.deactivate_account(
306 requester.user.to_string(), erase
307 requester.user.to_string(), erase, requester
307308 )
308309 return 200, {}
309310
310311 await self.auth_handler.validate_user_via_ui_auth(
312 requester, request, body, "deactivate your account",
313 )
314 result = await self._deactivate_account_handler.deactivate_account(
315 requester.user.to_string(),
316 erase,
311317 requester,
312 request,
313 body,
314 self.hs.get_ip_from_request(request),
315 "deactivate your account",
316 )
317 result = await self._deactivate_account_handler.deactivate_account(
318 requester.user.to_string(), erase, id_server=body.get("id_server")
318 id_server=body.get("id_server"),
319319 )
320320 if result:
321321 id_server_unbind_result = "success"
694694 assert_valid_client_secret(client_secret)
695695
696696 await self.auth_handler.validate_user_via_ui_auth(
697 requester,
698 request,
699 body,
700 self.hs.get_ip_from_request(request),
701 "add a third-party identifier to your account",
697 requester, request, body, "add a third-party identifier to your account",
702698 )
703699
704700 validation_session = await self.identity_handler.validate_threepid_session(
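The account.py changes above replace bare string keys such as "password_hash" with `UIAuthSessionDataConstants`, and drop the explicit client-IP argument from `validate_user_via_ui_auth`. The snippet below sketches only the constants idea, with a plain dict standing in for the UI-auth session store; it is not Synapse's implementation:

    class UIAuthSessionDataConstants:
        PASSWORD_HASH = "password_hash"
        REGISTERED_USER_ID = "registered_user_id"

    _sessions = {}  # session_id -> {key: value}; a stand-in for the real session store

    def set_session_data(session_id: str, key: str, value) -> None:
        _sessions.setdefault(session_id, {})[key] = value

    def get_session_data(session_id: str, key: str, default=None):
        return _sessions.get(session_id, {}).get(key, default)

    set_session_data("abc123", UIAuthSessionDataConstants.PASSWORD_HASH, "$2b$12$example")
    print(get_session_data("abc123", UIAuthSessionDataConstants.PASSWORD_HASH))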
3636 super().__init__()
3737 self.auth = hs.get_auth()
3838 self.store = hs.get_datastore()
39 self.notifier = hs.get_notifier()
40 self._is_worker = hs.config.worker_app is not None
39 self.handler = hs.get_account_data_handler()
4140
4241 async def on_PUT(self, request, user_id, account_data_type):
43 if self._is_worker:
44 raise Exception("Cannot handle PUT /account_data on worker")
45
4642 requester = await self.auth.get_user_by_req(request)
4743 if user_id != requester.user.to_string():
4844 raise AuthError(403, "Cannot add account data for other users.")
4945
5046 body = parse_json_object_from_request(request)
5147
52 max_id = await self.store.add_account_data_for_user(
53 user_id, account_data_type, body
54 )
55
56 self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
48 await self.handler.add_account_data_for_user(user_id, account_data_type, body)
5749
5850 return 200, {}
5951
8880 super().__init__()
8981 self.auth = hs.get_auth()
9082 self.store = hs.get_datastore()
91 self.notifier = hs.get_notifier()
92 self._is_worker = hs.config.worker_app is not None
83 self.handler = hs.get_account_data_handler()
9384
9485 async def on_PUT(self, request, user_id, room_id, account_data_type):
95 if self._is_worker:
96 raise Exception("Cannot handle PUT /account_data on worker")
97
9886 requester = await self.auth.get_user_by_req(request)
9987 if user_id != requester.user.to_string():
10088 raise AuthError(403, "Cannot add account data for other users.")
10896 " Use /rooms/!roomId:server.name/read_markers",
10997 )
11098
111 max_id = await self.store.add_account_data_to_room(
99 await self.handler.add_account_data_to_room(
112100 user_id, room_id, account_data_type, body
113101 )
114
115 self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
116102
117103 return 200, {}
118104
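In the account-data servlets above, the REST layer no longer writes to the datastore and pokes the notifier itself; both happen behind `hs.get_account_data_handler()`, which also removes the explicit "cannot run on a worker" guard. A toy sketch of that encapsulation (plain Python with fake store and notifier objects, not Synapse code):

    import asyncio

    class FakeStore:
        def __init__(self):
            self._stream_id = 0
        async def add_account_data_for_user(self, user_id, data_type, content):
            self._stream_id += 1          # pretend to persist and advance the stream
            return self._stream_id

    class FakeNotifier:
        def on_new_event(self, kind, max_id, users):
            print("wake up /sync for", users, "at stream position", max_id)

    class AccountDataHandler:
        """Owns both the write and the notification, so servlets only call this."""
        def __init__(self, store, notifier):
            self._store = store
            self._notifier = notifier
        async def add_account_data_for_user(self, user_id, data_type, content):
            max_id = await self._store.add_account_data_for_user(user_id, data_type, content)
            self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
            return max_id

    asyncio.run(
        AccountDataHandler(FakeStore(), FakeNotifier()).add_account_data_for_user(
            "@alice:example.org", "m.direct", {}
        )
    )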
1313 # limitations under the License.
1414
1515 import logging
16 from typing import TYPE_CHECKING
1617
1718 from synapse.api.constants import LoginType
1819 from synapse.api.errors import SynapseError
2122 from synapse.http.servlet import RestServlet, parse_string
2223
2324 from ._base import client_patterns
25
26 if TYPE_CHECKING:
27 from synapse.server import HomeServer
2428
2529 logger = logging.getLogger(__name__)
2630
3438
3539 PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
3640
37 def __init__(self, hs):
41 def __init__(self, hs: "HomeServer"):
3842 super().__init__()
3943 self.hs = hs
4044 self.auth = hs.get_auth()
4145 self.auth_handler = hs.get_auth_handler()
4246 self.registration_handler = hs.get_registration_handler()
43
44 # SSO configuration.
45 self._cas_enabled = hs.config.cas_enabled
46 if self._cas_enabled:
47 self._cas_handler = hs.get_cas_handler()
48 self._cas_server_url = hs.config.cas_server_url
49 self._cas_service_url = hs.config.cas_service_url
50 self._saml_enabled = hs.config.saml2_enabled
51 if self._saml_enabled:
52 self._saml_handler = hs.get_saml_handler()
53 self._oidc_enabled = hs.config.oidc_enabled
54 if self._oidc_enabled:
55 self._oidc_handler = hs.get_oidc_handler()
56 self._cas_server_url = hs.config.cas_server_url
57 self._cas_service_url = hs.config.cas_service_url
58
5947 self.recaptcha_template = hs.config.recaptcha_template
6048 self.terms_template = hs.config.terms_template
6149 self.success_template = hs.config.fallback_success_template
8472 elif stagetype == LoginType.SSO:
8573 # Display a confirmation page which prompts the user to
8674 # re-authenticate with their SSO provider.
87 if self._cas_enabled:
88 # Generate a request to CAS that redirects back to an endpoint
89 # to verify the successful authentication.
90 sso_redirect_url = self._cas_handler.get_redirect_url(
91 {"session": session},
92 )
93
94 elif self._saml_enabled:
95 # Some SAML identity providers (e.g. Google) require a
96 # RelayState parameter on requests. It is not necessary here, so
97 # pass in a dummy redirect URL (which will never get used).
98 client_redirect_url = b"unused"
99 sso_redirect_url = self._saml_handler.handle_redirect_request(
100 client_redirect_url, session
101 )
102
103 elif self._oidc_enabled:
104 client_redirect_url = b""
105 sso_redirect_url = await self._oidc_handler.handle_redirect_request(
106 request, client_redirect_url, session
107 )
108
109 else:
110 raise SynapseError(400, "Homeserver not configured for SSO.")
111
112 html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
75 html = await self.auth_handler.start_sso_ui_auth(request, session)
11376
11477 else:
11578 raise SynapseError(404, "Unknown auth stage type")
13396 authdict = {"response": response, "session": session}
13497
13598 success = await self.auth_handler.add_oob_auth(
136 LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
99 LoginType.RECAPTCHA, authdict, request.getClientIP()
137100 )
138101
139102 if success:
149112 authdict = {"session": session}
150113
151114 success = await self.auth_handler.add_oob_auth(
152 LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
115 LoginType.TERMS, authdict, request.getClientIP()
153116 )
154117
155118 if success:
8282 assert_params_in_dict(body, ["devices"])
8383
8484 await self.auth_handler.validate_user_via_ui_auth(
85 requester,
86 request,
87 body,
88 self.hs.get_ip_from_request(request),
89 "remove device(s) from your account",
85 requester, request, body, "remove device(s) from your account",
9086 )
9187
9288 await self.device_handler.delete_devices(
132128 raise
133129
134130 await self.auth_handler.validate_user_via_ui_auth(
135 requester,
136 request,
137 body,
138 self.hs.get_ip_from_request(request),
139 "remove a device from your account",
131 requester, request, body, "remove a device from your account",
140132 )
141133
142134 await self.device_handler.delete_device(requester.user.to_string(), device_id)
270270 body = parse_json_object_from_request(request)
271271
272272 await self.auth_handler.validate_user_via_ui_auth(
273 requester,
274 request,
275 body,
276 self.hs.get_ip_from_request(request),
277 "add a device signing key to your account",
273 requester, request, body, "add a device signing key to your account",
278274 )
279275
280276 result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
3737 from synapse.config.registration import RegistrationConfig
3838 from synapse.config.server import is_threepid_reserved
3939 from synapse.handlers.auth import AuthHandler
40 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
4041 from synapse.http.server import finish_request, respond_with_html
4142 from synapse.http.servlet import (
4243 RestServlet,
352353 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
353354 )
354355
355 ip = self.hs.get_ip_from_request(request)
356 ip = request.getClientIP()
356357 with self.ratelimiter.ratelimit(ip) as wait_deferred:
357358 await wait_deferred
358359
493494 # user here. We carry on and go through the auth checks though,
494495 # for paranoia.
495496 registered_user_id = await self.auth_handler.get_session_data(
496 session_id, "registered_user_id", None
497 session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
497498 )
498499 # Extract the previously-hashed password from the session.
499500 password_hash = await self.auth_handler.get_session_data(
500 session_id, "password_hash", None
501 session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
501502 )
502503
503504 # Ensure that the username is valid.
512513 # not this will raise a user-interactive auth error.
513514 try:
514515 auth_result, params, session_id = await self.auth_handler.check_ui_auth(
515 self._registration_flows,
516 request,
517 body,
518 self.hs.get_ip_from_request(request),
519 "register a new account",
516 self._registration_flows, request, body, "register a new account",
520517 )
521518 except InteractiveAuthIncompleteError as e:
522519 # The user needs to provide more steps to complete auth.
531528 if not password_hash and password:
532529 password_hash = await self.auth_handler.hash(password)
533530 await self.auth_handler.set_session_data(
534 e.session_id, "password_hash", password_hash
531 e.session_id,
532 UIAuthSessionDataConstants.PASSWORD_HASH,
533 password_hash,
535534 )
536535 raise
537536
632631 # Remember that the user account has been registered (and the user
633632 # ID it was registered with, since it might not have been specified).
634633 await self.auth_handler.set_session_data(
635 session_id, "registered_user_id", registered_user_id
634 session_id,
635 UIAuthSessionDataConstants.REGISTERED_USER_ID,
636 registered_user_id,
636637 )
637638
638639 registered = True
5757 def __init__(self, hs):
5858 super().__init__()
5959 self.auth = hs.get_auth()
60 self.store = hs.get_datastore()
61 self.notifier = hs.get_notifier()
60 self.handler = hs.get_account_data_handler()
6261
6362 async def on_PUT(self, request, user_id, room_id, tag):
6463 requester = await self.auth.get_user_by_req(request)
6766
6867 body = parse_json_object_from_request(request)
6968
70 max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
71
72 self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
69 await self.handler.add_tag_to_room(user_id, room_id, tag, body)
7370
7471 return 200, {}
7572
7875 if user_id != requester.user.to_string():
7976 raise AuthError(403, "Cannot add tags for other users.")
8077
81 max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
82
83 self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
78 await self.handler.remove_tag_from_room(user_id, room_id, tag)
8479
8580 return 200, {}
8681
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2019 New Vector Ltd
2 # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
33 #
44 # Licensed under the Apache License, Version 2.0 (the "License");
55 # you may not use this file except in compliance with the License.
1616 import logging
1717 import os
1818 import urllib
19 from typing import Awaitable
19 from typing import Awaitable, Dict, Generator, List, Optional, Tuple
2020
2121 from twisted.internet.interfaces import IConsumer
2222 from twisted.protocols.basic import FileSender
23 from twisted.web.http import Request
2324
2425 from synapse.api.errors import Codes, SynapseError, cs_error
2526 from synapse.http.server import finish_request, respond_with_json
4546 ]
4647
4748
48 def parse_media_id(request):
49 def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
4950 try:
5051 # This allows users to append e.g. /test.png to the URL. Useful for
5152 # clients that parse the URL to see content type.
6869 )
6970
7071
71 def respond_404(request):
72 def respond_404(request: Request) -> None:
7273 respond_with_json(
7374 request,
7475 404,
7879
7980
8081 async def respond_with_file(
81 request, media_type, file_path, file_size=None, upload_name=None
82 ):
82 request: Request,
83 media_type: str,
84 file_path: str,
85 file_size: Optional[int] = None,
86 upload_name: Optional[str] = None,
87 ) -> None:
8388 logger.debug("Responding with %r", file_path)
8489
8590 if os.path.isfile(file_path):
97102 respond_404(request)
98103
99104
100 def add_file_headers(request, media_type, file_size, upload_name):
105 def add_file_headers(
106 request: Request,
107 media_type: str,
108 file_size: Optional[int],
109 upload_name: Optional[str],
110 ) -> None:
101111 """Adds the correct response headers in preparation for responding with the
102112 media.
103113
104114 Args:
105 request (twisted.web.http.Request)
106 media_type (str): The media/content type.
107 file_size (int): Size in bytes of the media, if known.
108 upload_name (str): The name of the requested file, if any.
115 request
116 media_type: The media/content type.
117 file_size: Size in bytes of the media, if known.
118 upload_name: The name of the requested file, if any.
109119 """
110120
111121 def _quote(x):
152162 # select private. don't bother setting Expires as all our
153163 # clients are smart enough to be happy with Cache-Control
154164 request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
155 request.setHeader(b"Content-Length", b"%d" % (file_size,))
165 if file_size is not None:
166 request.setHeader(b"Content-Length", b"%d" % (file_size,))
156167
157168 # Tell web crawlers to not index, archive, or follow links in media. This
158169 # should help to prevent things in the media repo from showing up in web
183194 }
184195
185196
186 def _can_encode_filename_as_token(x):
197 def _can_encode_filename_as_token(x: str) -> bool:
187198 for c in x:
188199 # from RFC2616:
189200 #
205216
206217
207218 async def respond_with_responder(
208 request, responder, media_type, file_size, upload_name=None
209 ):
219 request: Request,
220 responder: "Optional[Responder]",
221 media_type: str,
222 file_size: Optional[int],
223 upload_name: Optional[str] = None,
224 ) -> None:
210225 """Responds to the request with given responder. If responder is None then
211226 returns 404.
212227
213228 Args:
214 request (twisted.web.http.Request)
215 responder (Responder|None)
216 media_type (str): The media/content type.
217 file_size (int|None): Size in bytes of the media. If not known it should be None
218 upload_name (str|None): The name of the requested file, if any.
229 request
230 responder
231 media_type: The media/content type.
232 file_size: Size in bytes of the media. If not known it should be None
233 upload_name: The name of the requested file, if any.
219234 """
220235 if request._disconnected:
221236 logger.warning(
307322 self.thumbnail_type = thumbnail_type
308323
309324
310 def get_filename_from_headers(headers):
325 def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
311326 """
312327 Get the filename of the downloaded file by inspecting the
313328 Content-Disposition HTTP header.
314329
315330 Args:
316 headers (dict[bytes, list[bytes]]): The HTTP request headers.
331 headers: The HTTP request headers.
317332
318333 Returns:
319 A Unicode string of the filename, or None.
334 The filename, or None.
320335 """
321336 content_disposition = headers.get(b"Content-Disposition", [b""])
322337
323338 # No header, bail out.
324339 if not content_disposition[0]:
325 return
340 return None
326341
327342 _, params = _parse_header(content_disposition[0])
328343
355370 return upload_name
356371
357372
358 def _parse_header(line):
373 def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
359374 """Parse a Content-type like header.
360375
361376 Cargo-culted from `cgi`, but works on bytes rather than strings.
362377
363378 Args:
364 line (bytes): header to be parsed
379 line: header to be parsed
365380
366381 Returns:
367 Tuple[bytes, dict[bytes, bytes]]:
368 the main content-type, followed by the parameter dictionary
382 The main content-type, followed by the parameter dictionary
369383 """
370384 parts = _parseparam(b";" + line)
371385 key = next(parts)
385399 return key, pdict
386400
387401
388 def _parseparam(s):
402 def _parseparam(s: bytes) -> Generator[bytes, None, None]:
389403 """Generator which splits the input on ;, respecting double-quoted sequences
390404
391405 Cargo-culted from `cgi`, but works on bytes rather than strings.
392406
393407 Args:
394 s (bytes): header to be parsed
408 s: header to be parsed
395409
396410 Returns:
397 Iterable[bytes]: the split input
411 The split input
398412 """
399413 while s[:1] == b";":
400414 s = s[1:]
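The `_base.py` hunks above mostly add type hints, but they also stop emitting a `Content-Length` header when the file size is unknown and keep the bytes-oriented Content-Disposition parsing. For comparison, the standard library can extract the same filename when working with text headers; this is only an illustration, not what Synapse uses:

    from email.message import EmailMessage

    def filename_from_content_disposition(value: str):
        """Return the filename parameter of a Content-Disposition header, if any."""
        msg = EmailMessage()
        msg["Content-Disposition"] = value
        return msg.get_filename()

    print(filename_from_content_disposition('inline; filename="cat.png"'))  # -> cat.png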
00 # -*- coding: utf-8 -*-
11 # Copyright 2018 Will Hunt <will@half-shot.uk>
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415 #
1516
17 from typing import TYPE_CHECKING
18
19 from twisted.web.http import Request
20
1621 from synapse.http.server import DirectServeJsonResource, respond_with_json
22
23 if TYPE_CHECKING:
24 from synapse.app.homeserver import HomeServer
1725
1826
1927 class MediaConfigResource(DirectServeJsonResource):
2028 isLeaf = True
2129
22 def __init__(self, hs):
30 def __init__(self, hs: "HomeServer"):
2331 super().__init__()
2432 config = hs.get_config()
2533 self.clock = hs.get_clock()
2634 self.auth = hs.get_auth()
2735 self.limits_dict = {"m.upload.size": config.max_upload_size}
2836
29 async def _async_render_GET(self, request):
37 async def _async_render_GET(self, request: Request) -> None:
3038 await self.auth.get_user_by_req(request)
3139 respond_with_json(request, 200, self.limits_dict, send_cors=True)
3240
33 async def _async_render_OPTIONS(self, request):
41 async def _async_render_OPTIONS(self, request: Request) -> None:
3442 respond_with_json(request, 200, {}, send_cors=True)
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415 import logging
16 from typing import TYPE_CHECKING
1517
16 import synapse.http.servlet
18 from twisted.web.http import Request
19
1720 from synapse.http.server import DirectServeJsonResource, set_cors_headers
21 from synapse.http.servlet import parse_boolean
1822
1923 from ._base import parse_media_id, respond_404
24
25 if TYPE_CHECKING:
26 from synapse.app.homeserver import HomeServer
27 from synapse.rest.media.v1.media_repository import MediaRepository
2028
2129 logger = logging.getLogger(__name__)
2230
2432 class DownloadResource(DirectServeJsonResource):
2533 isLeaf = True
2634
27 def __init__(self, hs, media_repo):
35 def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
2836 super().__init__()
2937 self.media_repo = media_repo
3038 self.server_name = hs.hostname
3139
32 async def _async_render_GET(self, request):
40 async def _async_render_GET(self, request: Request) -> None:
3341 set_cors_headers(request)
3442 request.setHeader(
3543 b"Content-Security-Policy",
4856 if server_name == self.server_name:
4957 await self.media_repo.get_local_media(request, media_id, name)
5058 else:
51 allow_remote = synapse.http.servlet.parse_boolean(
52 request, "allow_remote", default=True
53 )
59 allow_remote = parse_boolean(request, "allow_remote", default=True)
5460 if not allow_remote:
5561 logger.info(
5662 "Rejecting request for remote media %s/%s due to allow_remote",
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1516 import functools
1617 import os
1718 import re
19 from typing import Callable, List
1820
1921 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
2022
2123
22 def _wrap_in_base_path(func):
24 def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
2325 """Takes a function that returns a relative path and turns it into an
2426 absolute path based on the location of the primary media store
2527 """
4042 to write to the backup media store (when one is configured)
4143 """
4244
43 def __init__(self, primary_base_path):
45 def __init__(self, primary_base_path: str):
4446 self.base_path = primary_base_path
4547
4648 def default_thumbnail_rel(
47 self, default_top_level, default_sub_type, width, height, content_type, method
48 ):
49 self,
50 default_top_level: str,
51 default_sub_type: str,
52 width: int,
53 height: int,
54 content_type: str,
55 method: str,
56 ) -> str:
4957 top_level_type, sub_type = content_type.split("/")
5058 file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
5159 return os.path.join(
5462
5563 default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
5664
57 def local_media_filepath_rel(self, media_id):
65 def local_media_filepath_rel(self, media_id: str) -> str:
5866 return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
5967
6068 local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
6169
62 def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
70 def local_media_thumbnail_rel(
71 self, media_id: str, width: int, height: int, content_type: str, method: str
72 ) -> str:
6373 top_level_type, sub_type = content_type.split("/")
6474 file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
6575 return os.path.join(
8595 media_id[4:],
8696 )
8797
88 def remote_media_filepath_rel(self, server_name, file_id):
98 def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
8999 return os.path.join(
90100 "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
91101 )
93103 remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
94104
95105 def remote_media_thumbnail_rel(
96 self, server_name, file_id, width, height, content_type, method
97 ):
106 self,
107 server_name: str,
108 file_id: str,
109 width: int,
110 height: int,
111 content_type: str,
112 method: str,
113 ) -> str:
98114 top_level_type, sub_type = content_type.split("/")
99115 file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
100116 return os.path.join(
112128 # Should be removed after some time, when most of the thumbnails are stored
113129 # using the new path.
114130 def remote_media_thumbnail_rel_legacy(
115 self, server_name, file_id, width, height, content_type
131 self, server_name: str, file_id: str, width: int, height: int, content_type: str
116132 ):
117133 top_level_type, sub_type = content_type.split("/")
118134 file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
125141 file_name,
126142 )
127143
128 def remote_media_thumbnail_dir(self, server_name, file_id):
144 def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
129145 return os.path.join(
130146 self.base_path,
131147 "remote_thumbnail",
135151 file_id[4:],
136152 )
137153
138 def url_cache_filepath_rel(self, media_id):
154 def url_cache_filepath_rel(self, media_id: str) -> str:
139155 if NEW_FORMAT_ID_RE.match(media_id):
140156 # Media id is of the form <DATE><RANDOM_STRING>
141157 # E.g.: 2017-09-28-fsdRDt24DS234dsf
145161
146162 url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
147163
148 def url_cache_filepath_dirs_to_delete(self, media_id):
164 def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
149165 "The dirs to try and remove if we delete the media_id file"
150166 if NEW_FORMAT_ID_RE.match(media_id):
151167 return [os.path.join(self.base_path, "url_cache", media_id[:10])]
155171 os.path.join(self.base_path, "url_cache", media_id[0:2]),
156172 ]
157173
158 def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
174 def url_cache_thumbnail_rel(
175 self, media_id: str, width: int, height: int, content_type: str, method: str
176 ) -> str:
159177 # Media id is of the form <DATE><RANDOM_STRING>
160178 # E.g.: 2017-09-28-fsdRDt24DS234dsf
161179
177195
178196 url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
179197
180 def url_cache_thumbnail_directory(self, media_id):
198 def url_cache_thumbnail_directory(self, media_id: str) -> str:
181199 # Media id is of the form <DATE><RANDOM_STRING>
182200 # E.g.: 2017-09-28-fsdRDt24DS234dsf
183201
194212 media_id[4:],
195213 )
196214
197 def url_cache_thumbnail_dirs_to_delete(self, media_id):
215 def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
198216 "The dirs to try and remove if we delete the media_id thumbnails"
199217 # Media id is of the form <DATE><RANDOM_STRING>
200218 # E.g.: 2017-09-28-fsdRDt24DS234dsf
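The `MediaFilePaths` methods above gain type annotations without changing the layout they produce: media IDs are fanned out into two-character directory levels so that no single directory grows too large. A standalone illustration of that scheme (the base path and media ID are made up):

    import os

    def local_media_path(base: str, media_id: str) -> str:
        """Mirror the local_content fan-out shown in the diff."""
        return os.path.join(base, "local_content", media_id[0:2], media_id[2:4], media_id[4:])

    print(local_media_path("/var/lib/synapse/media", "GCmhgzMPRjqgpODLsNQzVuHZ"))
    # /var/lib/synapse/media/local_content/GC/mh/gzMPRjqgpODLsNQzVuHZ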
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
2 # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
33 #
44 # Licensed under the Apache License, Version 2.0 (the "License");
55 # you may not use this file except in compliance with the License.
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15
1615 import errno
1716 import logging
1817 import os
1918 import shutil
20 from typing import IO, Dict, List, Optional, Tuple
19 from io import BytesIO
20 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
2121
2222 import twisted.internet.error
2323 import twisted.web.http
5555 from .thumbnailer import Thumbnailer, ThumbnailError
5656 from .upload_resource import UploadResource
5757
58 if TYPE_CHECKING:
59 from synapse.app.homeserver import HomeServer
60
5861 logger = logging.getLogger(__name__)
5962
6063
6265
6366
6467 class MediaRepository:
65 def __init__(self, hs):
68 def __init__(self, hs: "HomeServer"):
6669 self.hs = hs
6770 self.auth = hs.get_auth()
6871 self.client = hs.get_federation_http_client()
7275 self.max_upload_size = hs.config.max_upload_size
7376 self.max_image_pixels = hs.config.max_image_pixels
7477
75 self.primary_base_path = hs.config.media_store_path
76 self.filepaths = MediaFilePaths(self.primary_base_path)
78 self.primary_base_path = hs.config.media_store_path # type: str
79 self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
7780
7881 self.dynamic_thumbnails = hs.config.dynamic_thumbnails
7982 self.thumbnail_requirements = hs.config.thumbnail_requirements
8083
8184 self.remote_media_linearizer = Linearizer(name="media_remote")
8285
83 self.recently_accessed_remotes = set()
84 self.recently_accessed_locals = set()
86 self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
87 self.recently_accessed_locals = set() # type: Set[str]
8588
8689 self.federation_domain_whitelist = hs.config.federation_domain_whitelist
8790
112115 "update_recently_accessed_media", self._update_recently_accessed
113116 )
114117
115 async def _update_recently_accessed(self):
118 async def _update_recently_accessed(self) -> None:
116119 remote_media = self.recently_accessed_remotes
117120 self.recently_accessed_remotes = set()
118121
123126 local_media, remote_media, self.clock.time_msec()
124127 )
125128
126 def mark_recently_accessed(self, server_name, media_id):
129 def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
127130 """Mark the given media as recently accessed.
128131
129132 Args:
130 server_name (str|None): Origin server of media, or None if local
131 media_id (str): The media ID of the content
133 server_name: Origin server of media, or None if local
134 media_id: The media ID of the content
132135 """
133136 if server_name:
134137 self.recently_accessed_remotes.add((server_name, media_id))
458461 def _get_thumbnail_requirements(self, media_type):
459462 return self.thumbnail_requirements.get(media_type, ())
460463
461 def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
464 def _generate_thumbnail(
465 self,
466 thumbnailer: Thumbnailer,
467 t_width: int,
468 t_height: int,
469 t_method: str,
470 t_type: str,
471 ) -> Optional[BytesIO]:
462472 m_width = thumbnailer.width
463473 m_height = thumbnailer.height
464474
469479 m_height,
470480 self.max_image_pixels,
471481 )
472 return
482 return None
473483
474484 if thumbnailer.transpose_method is not None:
475485 m_width, m_height = thumbnailer.transpose()
476486
477487 if t_method == "crop":
478 t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
488 return thumbnailer.crop(t_width, t_height, t_type)
479489 elif t_method == "scale":
480490 t_width, t_height = thumbnailer.aspect(t_width, t_height)
481491 t_width = min(m_width, t_width)
482492 t_height = min(m_height, t_height)
483 t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
484 else:
485 t_byte_source = None
486
487 return t_byte_source
493 return thumbnailer.scale(t_width, t_height, t_type)
494
495 return None
488496
489497 async def generate_local_exact_thumbnail(
490498 self,
775783
776784 return {"width": m_width, "height": m_height}
777785
778 async def delete_old_remote_media(self, before_ts):
786 async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
779787 old_media = await self.store.get_remote_media_before(before_ts)
780788
781789 deleted = 0
927935 within a given rectangle.
928936 """
929937
930 def __init__(self, hs):
938 def __init__(self, hs: "HomeServer"):
931939 # If we're not configured to use it, raise if we somehow got here.
932940 if not hs.config.can_load_media_repo:
933941 raise ConfigError("Synapse is not configured to use a media repo.")
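`_generate_thumbnail` above now has an explicit `Optional[BytesIO]` return type and returns the cropped or scaled bytes directly instead of threading a `t_byte_source` variable through every branch. A compressed sketch of the resulting control flow (the pixel limit and placeholder bytes are invented for illustration):

    from io import BytesIO
    from typing import Optional

    MAX_PIXELS = 32_000_000  # illustrative limit, not Synapse's configured value

    def generate_thumbnail(width: int, height: int, method: str) -> Optional[BytesIO]:
        if width * height > MAX_PIXELS:
            return None                      # refuse to thumbnail huge images
        if method == "crop":
            return BytesIO(b"cropped-thumbnail-bytes")
        elif method == "scale":
            return BytesIO(b"scaled-thumbnail-bytes")
        return None                          # unknown method

    print(generate_thumbnail(640, 480, "scale") is not None)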
00 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vecotr Ltd
1 # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
22 #
33 # Licensed under the Apache License, Version 2.0 (the "License");
44 # you may not use this file except in compliance with the License.
1717 import shutil
1818 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
1919
20 from twisted.internet.defer import Deferred
21 from twisted.internet.interfaces import IConsumer
2022 from twisted.protocols.basic import FileSender
2123
2224 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
269271 return self.filepaths.local_media_filepath_rel(file_info.file_id)
270272
271273
272 def _write_file_synchronously(source, dest):
274 def _write_file_synchronously(source: IO, dest: IO) -> None:
273275 """Write `source` to the file like `dest` synchronously. Should be called
274276 from a thread.
275277
285287 """Wraps an open file that can be sent to a request.
286288
287289 Args:
288 open_file (file): A file like object to be streamed to the client,
290 open_file: A file like object to be streamed to the client,
289291 is closed when finished streaming.
290292 """
291293
292 def __init__(self, open_file):
294 def __init__(self, open_file: IO):
293295 self.open_file = open_file
294296
295 def write_to_consumer(self, consumer):
297 def write_to_consumer(self, consumer: IConsumer) -> Deferred:
296298 return make_deferred_yieldable(
297299 FileSender().beginFileTransfer(self.open_file, consumer)
298300 )
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1112 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
14
1515 import datetime
1616 import errno
1717 import fnmatch
2222 import shutil
2323 import sys
2424 import traceback
25 from typing import Dict, Optional
25 from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
2626 from urllib import parse as urlparse
2727
2828 import attr
2929
3030 from twisted.internet.error import DNSLookupError
31 from twisted.web.http import Request
3132
3233 from synapse.api.errors import Codes, SynapseError
3334 from synapse.http.client import SimpleHttpClient
4041 from synapse.logging.context import make_deferred_yieldable, run_in_background
4142 from synapse.metrics.background_process_metrics import run_as_background_process
4243 from synapse.rest.media.v1._base import get_filename_from_headers
44 from synapse.rest.media.v1.media_storage import MediaStorage
4345 from synapse.util import json_encoder
4446 from synapse.util.async_helpers import ObservableDeferred
4547 from synapse.util.caches.expiringcache import ExpiringCache
4648 from synapse.util.stringutils import random_string
4749
4850 from ._base import FileInfo
51
52 if TYPE_CHECKING:
53 from lxml import etree
54
55 from synapse.app.homeserver import HomeServer
56 from synapse.rest.media.v1.media_repository import MediaRepository
4957
5058 logger = logging.getLogger(__name__)
5159
118126 class PreviewUrlResource(DirectServeJsonResource):
119127 isLeaf = True
120128
121 def __init__(self, hs, media_repo, media_storage):
129 def __init__(
130 self,
131 hs: "HomeServer",
132 media_repo: "MediaRepository",
133 media_storage: MediaStorage,
134 ):
122135 super().__init__()
123136
124137 self.auth = hs.get_auth()
165178 self._start_expire_url_cache_data, 10 * 1000
166179 )
167180
168 async def _async_render_OPTIONS(self, request):
181 async def _async_render_OPTIONS(self, request: Request) -> None:
169182 request.setHeader(b"Allow", b"OPTIONS, GET")
170183 respond_with_json(request, 200, {}, send_cors=True)
171184
172 async def _async_render_GET(self, request):
185 async def _async_render_GET(self, request: Request) -> None:
173186
174187 # XXX: if get_user_by_req fails, what should we do in an async render?
175188 requester = await self.auth.get_user_by_req(request)
449462 logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
450463 raise OEmbedError() from e
451464
452 async def _download_url(self, url: str, user):
465 async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
453466 # TODO: we should probably honour robots.txt... except in practice
454467 # we're most likely being explicitly triggered by a human rather than a
455468 # bot, so are we really a robot?
579592 "expire_url_cache_data", self._expire_url_cache_data
580593 )
581594
582 async def _expire_url_cache_data(self):
595 async def _expire_url_cache_data(self) -> None:
583596 """Clean up expired url cache content, media and thumbnails.
584597 """
585598 # TODO: Delete from backup media store
675688 logger.debug("No media removed from url cache")
676689
677690
678 def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
691 def decode_and_calc_og(
692 body: bytes, media_uri: str, request_encoding: Optional[str] = None
693 ) -> Dict[str, Optional[str]]:
679694 # If there's no body, nothing useful is going to be found.
680695 if not body:
681696 return {}
696711 return og
697712
698713
699 def _calc_og(tree, media_uri):
714 def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
700715 # suck our tree into lxml and define our OG response.
701716
702717 # if we see any image URLs in the OG response, then spider them
800815 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
801816 )
802817 og["og:description"] = summarize_paragraphs(text_nodes)
803 else:
818 elif og["og:description"]:
819 # This must be a non-empty string at this point.
820 assert isinstance(og["og:description"], str)
804821 og["og:description"] = summarize_paragraphs([og["og:description"]])
805822
806823 # TODO: delete the url downloads to stop diskfilling,
808825 return og
809826
810827
811 def _iterate_over_text(tree, *tags_to_ignore):
828 def _iterate_over_text(
829 tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
830 ) -> Generator[str, None, None]:
812831 """Iterate over the tree returning text nodes in a depth first fashion,
813832 skipping text nodes inside certain tags.
814833 """
842861 )
843862
844863
845 def _rebase_url(url, base):
846 base = list(urlparse.urlparse(base))
847 url = list(urlparse.urlparse(url))
848 if not url[0]: # fix up schema
849 url[0] = base[0] or "http"
850 if not url[1]: # fix up hostname
851 url[1] = base[1]
852 if not url[2].startswith("/"):
853 url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
854 return urlparse.urlunparse(url)
855
856
857 def _is_media(content_type):
858 if content_type.lower().startswith("image/"):
859 return True
860
861
862 def _is_html(content_type):
864 def _rebase_url(url: str, base: str) -> str:
865 base_parts = list(urlparse.urlparse(base))
866 url_parts = list(urlparse.urlparse(url))
867 if not url_parts[0]: # fix up schema
868 url_parts[0] = base_parts[0] or "http"
869 if not url_parts[1]: # fix up hostname
870 url_parts[1] = base_parts[1]
871 if not url_parts[2].startswith("/"):
872 url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
873 return urlparse.urlunparse(url_parts)
874
875
876 def _is_media(content_type: str) -> bool:
877 return content_type.lower().startswith("image/")
878
879
880 def _is_html(content_type: str) -> bool:
863881 content_type = content_type.lower()
864 if content_type.startswith("text/html") or content_type.startswith(
882 return content_type.startswith("text/html") or content_type.startswith(
865883 "application/xhtml"
866 ):
867 return True
868
869
870 def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
884 )
885
886
887 def summarize_paragraphs(
888 text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
889 ) -> Optional[str]:
871890 # Try to get a summary of between 200 and 500 words, respecting
872891 # first paragraph and then word boundaries.
873892 # TODO: Respect sentences?
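The refactored `_rebase_url` above renames its locals while keeping the same behaviour: turn a possibly-relative OpenGraph image URL into an absolute one using the page URL as a base. The standard library's `urljoin` does roughly the same job for the common cases (it differs in details such as not defaulting the scheme to "http"):

    from urllib.parse import urljoin

    base = "https://example.com/articles/2021/post.html"
    print(urljoin(base, "/images/cover.png"))  # absolute path on the same host
    print(urljoin(base, "cover.png"))          # relative to the page's directory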
00 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
1 # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
22 #
33 # Licensed under the Apache License, Version 2.0 (the "License");
44 # you may not use this file except in compliance with the License.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import abc
1516 import logging
1617 import os
1718 import shutil
18 from typing import Optional
19 from typing import TYPE_CHECKING, Optional
1920
2021 from synapse.config._base import Config
2122 from synapse.logging.context import defer_to_thread, run_in_background
2627
2728 logger = logging.getLogger(__name__)
2829
30 if TYPE_CHECKING:
31 from synapse.app.homeserver import HomeServer
2932
30 class StorageProvider:
33
34 class StorageProvider(metaclass=abc.ABCMeta):
3135 """A storage provider is a service that can store uploaded media and
3236 retrieve them.
3337 """
3438
35 async def store_file(self, path: str, file_info: FileInfo):
39 @abc.abstractmethod
40 async def store_file(self, path: str, file_info: FileInfo) -> None:
3641 """Store the file described by file_info. The actual contents can be
3742 retrieved by reading the file in file_info.upload_path.
3843
4146 file_info: The metadata of the file.
4247 """
4348
49 @abc.abstractmethod
4450 async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
4551 """Attempt to fetch the file described by file_info and stream it
4652 into writer.
7783 self.store_synchronous = store_synchronous
7884 self.store_remote = store_remote
7985
80 def __str__(self):
86 def __str__(self) -> str:
8187 return "StorageProviderWrapper[%s]" % (self.backend,)
8288
83 async def store_file(self, path, file_info):
89 async def store_file(self, path: str, file_info: FileInfo) -> None:
8490 if not file_info.server_name and not self.store_local:
8591 return None
8692
9096 if self.store_synchronous:
9197 # store_file is supposed to return an Awaitable, but guard
9298 # against improper implementations.
93 return await maybe_awaitable(self.backend.store_file(path, file_info))
99 await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
94100 else:
95101 # TODO: Handle errors.
96102 async def store():
102108 logger.exception("Error storing file")
103109
104110 run_in_background(store)
105 return None
106111
107 async def fetch(self, path, file_info):
112 async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
108113 # store_file is supposed to return an Awaitable, but guard
109114 # against improper implementations.
110115 return await maybe_awaitable(self.backend.fetch(path, file_info))
114119 """A storage provider that stores files in a directory on a filesystem.
115120
116121 Args:
117 hs (HomeServer)
122 hs
118123 config: The config returned by `parse_config`.
119124 """
120125
121 def __init__(self, hs, config):
126 def __init__(self, hs: "HomeServer", config: str):
122127 self.hs = hs
123128 self.cache_directory = hs.config.media_store_path
124129 self.base_directory = config
126131 def __str__(self):
127132 return "FileStorageProviderBackend[%s]" % (self.base_directory,)
128133
129 async def store_file(self, path, file_info):
134 async def store_file(self, path: str, file_info: FileInfo) -> None:
130135 """See StorageProvider.store_file"""
131136
132137 primary_fname = os.path.join(self.cache_directory, path)
136141 if not os.path.exists(dirname):
137142 os.makedirs(dirname)
138143
139 return await defer_to_thread(
144 await defer_to_thread(
140145 self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
141146 )
142147
143 async def fetch(self, path, file_info):
148 async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
144149 """See StorageProvider.fetch"""
145150
146151 backup_fname = os.path.join(self.base_directory, path)
147152 if os.path.isfile(backup_fname):
148153 return FileResponder(open(backup_fname, "rb"))
149154
155 return None
156
150157 @staticmethod
151 def parse_config(config):
158 def parse_config(config: dict) -> str:
152159 """Called on startup to parse config supplied. This should parse
153160 the config and raise if there is a problem.
154161
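`StorageProvider` above becomes an abstract base class, so a provider that forgets to implement `store_file` or `fetch` now fails at instantiation rather than silently doing nothing. A toy in-memory provider showing the pattern (not one of Synapse's backends):

    import abc
    import asyncio
    from typing import Dict, Optional

    class StorageProvider(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        async def store_file(self, path: str, data: bytes) -> None:
            ...

        @abc.abstractmethod
        async def fetch(self, path: str) -> Optional[bytes]:
            ...

    class InMemoryProvider(StorageProvider):
        def __init__(self):
            self._files: Dict[str, bytes] = {}
        async def store_file(self, path: str, data: bytes) -> None:
            self._files[path] = data
        async def fetch(self, path: str) -> Optional[bytes]:
            return self._files.get(path)

    async def main():
        provider = InMemoryProvider()   # a subclass missing a method would raise TypeError here
        await provider.store_file("ab/cd/efgh", b"contents")
        print(await provider.fetch("ab/cd/efgh"))

    asyncio.run(main())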
00 # -*- coding: utf-8 -*-
1 # Copyright 2014 - 2016 OpenMarket Ltd
1 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1415
1516
1617 import logging
18 from typing import TYPE_CHECKING
19
20 from twisted.web.http import Request
1721
1822 from synapse.api.errors import SynapseError
1923 from synapse.http.server import DirectServeJsonResource, set_cors_headers
2024 from synapse.http.servlet import parse_integer, parse_string
25 from synapse.rest.media.v1.media_storage import MediaStorage
2126
2227 from ._base import (
2328 FileInfo,
2732 respond_with_responder,
2833 )
2934
35 if TYPE_CHECKING:
36 from synapse.app.homeserver import HomeServer
37 from synapse.rest.media.v1.media_repository import MediaRepository
38
3039 logger = logging.getLogger(__name__)
3140
3241
3342 class ThumbnailResource(DirectServeJsonResource):
3443 isLeaf = True
3544
36 def __init__(self, hs, media_repo, media_storage):
45 def __init__(
46 self,
47 hs: "HomeServer",
48 media_repo: "MediaRepository",
49 media_storage: MediaStorage,
50 ):
3751 super().__init__()
3852
3953 self.store = hs.get_datastore()
4256 self.dynamic_thumbnails = hs.config.dynamic_thumbnails
4357 self.server_name = hs.hostname
4458
45 async def _async_render_GET(self, request):
59 async def _async_render_GET(self, request: Request) -> None:
4660 set_cors_headers(request)
4761 server_name, media_id, _ = parse_media_id(request)
4862 width = parse_integer(request, "width", required=True)
7286 self.media_repo.mark_recently_accessed(server_name, media_id)
7387
7488 async def _respond_local_thumbnail(
75 self, request, media_id, width, height, method, m_type
76 ):
89 self,
90 request: Request,
91 media_id: str,
92 width: int,
93 height: int,
94 method: str,
95 m_type: str,
96 ) -> None:
7797 media_info = await self.store.get_local_media(media_id)
7898
7999 if not media_info:
113133
114134 async def _select_or_generate_local_thumbnail(
115135 self,
116 request,
117 media_id,
118 desired_width,
119 desired_height,
120 desired_method,
121 desired_type,
122 ):
136 request: Request,
137 media_id: str,
138 desired_width: int,
139 desired_height: int,
140 desired_method: str,
141 desired_type: str,
142 ) -> None:
123143 media_info = await self.store.get_local_media(media_id)
124144
125145 if not media_info:
177197
178198 async def _select_or_generate_remote_thumbnail(
179199 self,
180 request,
181 server_name,
182 media_id,
183 desired_width,
184 desired_height,
185 desired_method,
186 desired_type,
187 ):
200 request: Request,
201 server_name: str,
202 media_id: str,
203 desired_width: int,
204 desired_height: int,
205 desired_method: str,
206 desired_type: str,
207 ) -> None:
188208 media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
189209
190210 thumbnail_infos = await self.store.get_remote_media_thumbnails(
238258 raise SynapseError(400, "Failed to generate thumbnail.")
239259
240260 async def _respond_remote_thumbnail(
241 self, request, server_name, media_id, width, height, method, m_type
242 ):
261 self,
262 request: Request,
263 server_name: str,
264 media_id: str,
265 width: int,
266 height: int,
267 method: str,
268 m_type: str,
269 ) -> None:
243270 # TODO: Don't download the whole remote file
244271 # We should proxy the thumbnail from the remote server instead of
245272 # downloading the remote file and generating our own thumbnails.
274301
275302 def _select_thumbnail(
276303 self,
277 desired_width,
278 desired_height,
279 desired_method,
280 desired_type,
304 desired_width: int,
305 desired_height: int,
306 desired_method: str,
307 desired_type: str,
281308 thumbnail_infos,
282 ):
309 ) -> dict:
283310 d_w = desired_width
284311 d_h = desired_height
285312
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415 import logging
1516 from io import BytesIO
17 from typing import Tuple
1618
1719 from PIL import Image
1820
3840
3941 FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
4042
41 def __init__(self, input_path):
43 def __init__(self, input_path: str):
4244 try:
4345 self.image = Image.open(input_path)
4446 except OSError as e:
5860 # A lot of parsing errors can happen when parsing EXIF
5961 logger.info("Error parsing image EXIF information: %s", e)
6062
61 def transpose(self):
63 def transpose(self) -> Tuple[int, int]:
6264 """Transpose the image using its EXIF Orientation tag
6365
6466 Returns:
65 Tuple[int, int]: (width, height) containing the new image size in pixels.
67 A tuple containing the new image size in pixels as (width, height).
6668 """
6769 if self.transpose_method is not None:
6870 self.image = self.image.transpose(self.transpose_method)
7274 self.image.info["exif"] = None
7375 return self.image.size
7476
75 def aspect(self, max_width, max_height):
77 def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
7678 """Calculate the largest size that preserves aspect ratio which
7779 fits within the given rectangle::
7880
9092 else:
9193 return (max_height * self.width) // self.height, max_height
9294
93 def _resize(self, width, height):
95 def _resize(self, width: int, height: int) -> Image:
9496 # 1-bit or 8-bit color palette images need converting to RGB
9597 # otherwise they will be scaled using nearest neighbour which
9698 # looks awful
98100 self.image = self.image.convert("RGB")
99101 return self.image.resize((width, height), Image.ANTIALIAS)
100102
101 def scale(self, width, height, output_type):
103 def scale(self, width: int, height: int, output_type: str) -> BytesIO:
102104 """Rescales the image to the given dimensions.
103105
104106 Returns:
107109 scaled = self._resize(width, height)
108110 return self._encode_image(scaled, output_type)
109111
110 def crop(self, width, height, output_type):
112 def crop(self, width: int, height: int, output_type: str) -> BytesIO:
111113 """Rescales and crops the image to the given dimensions preserving
112114 aspect::
113115 (w_in / h_in) = (w_scaled / h_scaled)
135137 cropped = scaled_image.crop((crop_left, 0, crop_right, height))
136138 return self._encode_image(cropped, output_type)
137139
138 def _encode_image(self, output_image, output_type):
140 def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
139141 output_bytes_io = BytesIO()
140142 fmt = self.FORMATS[output_type]
141143 if fmt == "JPEG":
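The `Thumbnailer` methods above are annotated but unchanged in behaviour; `aspect()` still picks the largest size that preserves the source ratio while fitting the requested box. Essentially the same arithmetic as a free function (based on the branch visible in the hunk; the condition line itself is outside the shown context):

    def aspect(width: int, height: int, max_width: int, max_height: int):
        """Largest (w, h) with the same ratio as (width, height) that fits the box."""
        if max_width * height < max_height * width:
            return max_width, (max_width * height) // width
        return (max_height * width) // height, max_height

    print(aspect(1920, 1080, 640, 640))  # -> (640, 360)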
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415
1516 import logging
17 from typing import TYPE_CHECKING
18
19 from twisted.web.http import Request
1620
1721 from synapse.api.errors import Codes, SynapseError
1822 from synapse.http.server import DirectServeJsonResource, respond_with_json
1923 from synapse.http.servlet import parse_string
24
25 if TYPE_CHECKING:
26 from synapse.app.homeserver import HomeServer
27 from synapse.rest.media.v1.media_repository import MediaRepository
2028
2129 logger = logging.getLogger(__name__)
2230
2432 class UploadResource(DirectServeJsonResource):
2533 isLeaf = True
2634
27 def __init__(self, hs, media_repo):
35 def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
2836 super().__init__()
2937
3038 self.media_repo = media_repo
3644 self.max_upload_size = hs.config.max_upload_size
3745 self.clock = hs.get_clock()
3846
39 async def _async_render_OPTIONS(self, request):
47 async def _async_render_OPTIONS(self, request: Request) -> None:
4048 respond_with_json(request, 200, {}, send_cors=True)
4149
42 async def _async_render_POST(self, request):
50 async def _async_render_POST(self, request: Request) -> None:
4351 requester = await self.auth.get_user_by_req(request)
4452 # TODO: The checks here are a bit late. The content will have
4553 # already been uploaded to a tmp file at this point
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import logging
15 from typing import TYPE_CHECKING
16
17 from synapse.http.server import (
18 DirectServeHtmlResource,
19 finish_request,
20 respond_with_html,
21 )
22 from synapse.http.servlet import parse_string
23 from synapse.http.site import SynapseRequest
24
25 if TYPE_CHECKING:
26 from synapse.server import HomeServer
27
28 logger = logging.getLogger(__name__)
29
30
31 class PickIdpResource(DirectServeHtmlResource):
32 """IdP picker resource.
33
34 This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
35 which prompts the user to choose an Identity Provider from the list.
36 """
37
38 def __init__(self, hs: "HomeServer"):
39 super().__init__()
40 self._sso_handler = hs.get_sso_handler()
41 self._sso_login_idp_picker_template = (
42 hs.config.sso.sso_login_idp_picker_template
43 )
44 self._server_name = hs.hostname
45
46 async def _async_render_GET(self, request: SynapseRequest) -> None:
47 client_redirect_url = parse_string(
48 request, "redirectUrl", required=True, encoding="utf-8"
49 )
50 idp = parse_string(request, "idp", required=False)
51
52 # if we need to pick an IdP, do so
53 if not idp:
54 return await self._serve_id_picker(request, client_redirect_url)
55
56 # otherwise, redirect to the IdP's redirect URI
57 providers = self._sso_handler.get_identity_providers()
58 auth_provider = providers.get(idp)
59 if not auth_provider:
60 logger.info("Unknown idp %r", idp)
61 self._sso_handler.render_error(
62 request, "unknown_idp", "Unknown identity provider ID"
63 )
64 return
65
66 sso_url = await auth_provider.handle_redirect_request(
67 request, client_redirect_url.encode("utf8")
68 )
69 logger.info("Redirecting to %s", sso_url)
70 request.redirect(sso_url)
71 finish_request(request)
72
73 async def _serve_id_picker(
74 self, request: SynapseRequest, client_redirect_url: str
75 ) -> None:
76 # otherwise, serve up the IdP picker
77 providers = self._sso_handler.get_identity_providers()
78 html = self._sso_login_idp_picker_template.render(
79 redirect_url=client_redirect_url,
80 server_name=self._server_name,
81 providers=providers.values(),
82 )
83 respond_with_html(request, 200, html)
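`PickIdpResource` above is a new resource: without an `idp` query parameter it renders the provider-picker template, and with one it forwards the browser to that provider's redirect URL. A small helper showing how a login page might link to it; the path and parameter names come from the file above, while the homeserver host and the `oidc-github` provider ID are examples only:

    from typing import Optional
    from urllib.parse import urlencode

    def pick_idp_url(
        public_baseurl: str, client_redirect_url: str, idp: Optional[str] = None
    ) -> str:
        params = {"redirectUrl": client_redirect_url}
        if idp:
            params["idp"] = idp
        return public_baseurl.rstrip("/") + "/_synapse/client/pick_idp?" + urlencode(params)

    print(pick_idp_url("https://matrix.example.org", "https://app.example.com/done"))
    print(pick_idp_url("https://matrix.example.org", "https://app.example.com/done", "oidc-github"))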
3333 self._config = hs.config
3434
3535 def get_well_known(self):
36 # if we don't have a public_baseurl, we can't help much here.
37 if self._config.public_baseurl is None:
38 return None
39
4036 result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
4137
4238 if self._config.default_identity_server:
5454 from synapse.federation.transport.client import TransportLayerClient
5555 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
5656 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
57 from synapse.handlers.account_data import AccountDataHandler
5758 from synapse.handlers.account_validity import AccountValidityHandler
5859 from synapse.handlers.acme import AcmeHandler
5960 from synapse.handlers.admin import AdminHandler
282283 """
283284 return self._reactor
284285
285 def get_ip_from_request(self, request) -> str:
286 # X-Forwarded-For is handled by our custom request type.
287 return request.getClientIP()
288
289286 def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
290287 return domain_specific_string.domain == self.hostname
291288
504501 return InitialSyncHandler(self)
505502
506503 @cache_in_self
507 def get_profile_handler(self):
504 def get_profile_handler(self) -> ProfileHandler:
508505 return ProfileHandler(self)
509506
510507 @cache_in_self
714711 def get_module_api(self) -> ModuleApi:
715712 return ModuleApi(self, self.get_auth_handler())
716713
714 @cache_in_self
715 def get_account_data_handler(self) -> AccountDataHandler:
716 return AccountDataHandler(self)
717
717718 async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
718719 return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
719720
4141 self._auth = hs.get_auth()
4242 self._config = hs.config
4343 self._resouce_limited = False
44 self._account_data_handler = hs.get_account_data_handler()
4445 self._message_handler = hs.get_message_handler()
4546 self._state = hs.get_state_handler()
4647
176177 # tag already present, nothing to do here
177178 need_to_set_tag = False
178179 if need_to_set_tag:
179 max_id = await self._store.add_tag_to_room(
180 max_id = await self._account_data_handler.add_tag_to_room(
180181 user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
181182 )
182183 self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
3434
3535 self._store = hs.get_datastore()
3636 self._config = hs.config
37 self._account_data_handler = hs.get_account_data_handler()
3738 self._room_creation_handler = hs.get_room_creation_handler()
3839 self._room_member_handler = hs.get_room_member_handler()
3940 self._event_creation_handler = hs.get_event_creation_handler()
162163 )
163164 room_id = info["room_id"]
164165
165 max_id = await self._store.add_tag_to_room(
166 max_id = await self._account_data_handler.add_tag_to_room(
166167 user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
167168 )
168169 self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
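
An illustrative aside (not part of the upstream diff): the two server-notice hunks above stop writing room tags directly through the datastore and go through the new account data handler instead. A plausible reading, hedged because the handler itself is not shown in this diff, is that the handler lets a non-writer worker route the write to whichever instance is configured as the account-data writer. A toy sketch of that shape, with made-up names, not Synapse's actual handler:

class ToyAccountDataHandler:
    """Toy illustration only: route tag writes to the configured writer."""

    def __init__(self, instance_name: str, writer_instance: str, store, replication_client):
        self._is_writer = instance_name == writer_instance
        self._store = store                      # real DB writes happen here
        self._replication = replication_client   # hypothetical RPC client to the writer

    async def add_tag_to_room(self, user_id: str, room_id: str, tag: str, content: dict) -> int:
        if self._is_writer:
            # This instance owns the account data stream, write locally.
            return await self._store.add_tag_to_room(user_id, room_id, tag, content)
        # Otherwise ask the writer instance to do it and return the stream id it allocated.
        return await self._replication.add_tag_to_room(user_id, room_id, tag, content)
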
2828 form {
2929 text-align: center;
3030 margin: 10px 0 0 0;
31 }
32
33 ul.radiobuttons {
34 text-align: left;
35 list-style: none;
3136 }
3237
3338 /*
4141 from synapse.config.database import DatabaseConnectionConfig
4242 from synapse.logging.context import (
4343 LoggingContext,
44 LoggingContextOrSentinel,
4544 current_context,
4645 make_deferred_yieldable,
4746 )
4948 from synapse.storage.background_updates import BackgroundUpdater
5049 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
5150 from synapse.storage.types import Connection, Cursor
51 from synapse.storage.util.sequence import build_sequence_generator
5252 from synapse.types import Collection
5353
5454 # python 3 does not have a maximum int value
179179 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
180180
181181
182 R = TypeVar("R")
183
184
182185 class LoggingTransaction:
183186 """An object that almost-transparently proxies for the 'txn' object
184187 passed to the constructor. Adds logging and metrics to the .execute()
266269 for val in args:
267270 self.execute(sql, val)
268271
272 def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
273 """Corresponds to psycopg2.extras.execute_values. Only available when
274 using postgres.
275
276 Always sets fetch=True when calling `execute_values`, so it will return the
277 results.
278 """
279 assert isinstance(self.database_engine, PostgresEngine)
280 from psycopg2.extras import execute_values # type: ignore
281
282 return self._do_execute(
283 lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
284 )
285
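
An illustrative aside (not part of the upstream diff): a standalone sketch of the psycopg2 behaviour the wrapper above relies on. `execute_values` expands a single %s placeholder into a VALUES list and, with fetch=True, returns the rows the query produces. The table, columns and connection details below are made up for the example; this only works against PostgreSQL.

from psycopg2.extras import execute_values


def fetch_rows_in_ranges(conn, ranges):
    """Select rows whose seq falls inside any of the given (id, lo, hi) ranges."""
    sql = """
        SELECT t.item_id
        FROM items AS t, (VALUES %s) AS v(id, lo, hi)
        WHERE t.group_id = v.id AND t.seq > v.lo AND t.seq <= v.hi
    """
    with conn.cursor() as cur:
        # fetch=True mirrors the calling convention of the wrapper above.
        return execute_values(cur, sql, ranges, fetch=True)


# Example (hypothetical):
#   conn = psycopg2.connect("dbname=synapse")
#   fetch_rows_in_ranges(conn, [(1, 0, 10), (2, 5, 7)])
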
269286 def execute(self, sql: str, *args: Any) -> None:
270287 self._do_execute(self.txn.execute, sql, *args)
271288
276293 "Strip newlines out of SQL so that the loggers in the DB are on one line"
277294 return " ".join(line.strip() for line in sql.splitlines() if line.strip())
278295
279 def _do_execute(self, func, sql: str, *args: Any) -> None:
296 def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
280297 sql = self._make_sql_one_line(sql)
281298
282299 # TODO(paul): Maybe use 'info' and 'debug' for values?
347364 return top_n_counters
348365
349366
350 R = TypeVar("R")
351
352
353367 class DatabasePool:
354368 """Wraps a single physical database and connection pool.
355369
397411 "upsert_safety_check",
398412 self._check_safe_to_upsert,
399413 )
414
415 # We define this sequence here so that it can be referenced from both
416 # the DataStore and PersistEventsStore.
417 def get_chain_id_txn(txn):
418 txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
419 return txn.fetchone()[0]
420
421 self.event_chain_id_gen = build_sequence_generator(
422 engine, get_chain_id_txn, "event_auth_chain_id"
423 )
400424
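
An illustrative aside (not part of the upstream diff): on SQLite there is no database sequence, so a generator like the one built above has to be seeded from a "current maximum" query (here, the max chain_id) and then hand out strictly increasing ids in process. The toy below shows only the shape of that contract, not Synapse's implementation.

import threading


class ToyLocalSequence:
    def __init__(self, get_current_max):
        self._lock = threading.Lock()
        # e.g. seeded from SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains
        self._current = get_current_max()

    def get_next_id(self) -> int:
        with self._lock:
            self._current += 1
            return self._current


# seq = ToyLocalSequence(lambda: 0)
# seq.get_next_id()  # -> 1
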
401425 def is_running(self) -> bool:
402426 """Is the database pool currently running
670694 Returns:
671695 The result of func
672696 """
673 parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
674 if not parent_context:
697 curr_context = current_context()
698 if not curr_context:
675699 logger.warning(
676700 "Starting db connection from sentinel context: metrics will be lost"
677701 )
678702 parent_context = None
703 else:
704 assert isinstance(curr_context, LoggingContext)
705 parent_context = curr_context
679706
680707 start_time = monotonic_time()
681708
126126 self._presence_id_gen = StreamIdGenerator(
127127 db_conn, "presence_stream", "stream_id"
128128 )
129 self._device_inbox_id_gen = StreamIdGenerator(
130 db_conn, "device_inbox", "stream_id"
131 )
132129 self._public_room_id_gen = StreamIdGenerator(
133130 db_conn, "public_room_list_stream", "stream_id"
134131 )
162159 database,
163160 stream_name="caches",
164161 instance_name=hs.get_instance_name(),
165 table="cache_invalidation_stream_by_instance",
166 instance_column="instance_name",
167 id_column="stream_id",
162 tables=[
163 (
164 "cache_invalidation_stream_by_instance",
165 "instance_name",
166 "stream_id",
167 )
168 ],
168169 sequence_name="cache_invalidation_stream_seq",
169170 writers=[],
170171 )
186187 "PresenceStreamChangeCache",
187188 min_presence_val,
188189 prefilled_cache=presence_cache_prefill,
189 )
190
191 max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
192 device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
193 db_conn,
194 "device_inbox",
195 entity_column="user_id",
196 stream_column="stream_id",
197 max_value=max_device_inbox_id,
198 limit=1000,
199 )
200 self._device_inbox_stream_cache = StreamChangeCache(
201 "DeviceInboxStreamChangeCache",
202 min_device_inbox_id,
203 prefilled_cache=device_inbox_prefill,
204 )
205 # The federation outbox and the local device inbox uses the same
206 # stream_id generator.
207 device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
208 db_conn,
209 "device_federation_outbox",
210 entity_column="destination",
211 stream_column="stream_id",
212 max_value=max_device_inbox_id,
213 limit=1000,
214 )
215 self._device_federation_outbox_stream_cache = StreamChangeCache(
216 "DeviceFederationOutboxStreamChangeCache",
217 min_device_outbox_id,
218 prefilled_cache=device_outbox_prefill,
219190 )
220191
221192 device_list_max = self._device_list_id_gen.get_current_token()
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 import abc
1716 import logging
18 from typing import Dict, List, Optional, Tuple
17 from typing import Dict, List, Optional, Set, Tuple
1918
2019 from synapse.api.constants import AccountDataTypes
20 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
21 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
2122 from synapse.storage._base import SQLBaseStore, db_to_json
2223 from synapse.storage.database import DatabasePool
23 from synapse.storage.util.id_generators import StreamIdGenerator
24 from synapse.storage.engines import PostgresEngine
25 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
2426 from synapse.types import JsonDict
2527 from synapse.util import json_encoder
26 from synapse.util.caches.descriptors import _CacheContext, cached
28 from synapse.util.caches.descriptors import cached
2729 from synapse.util.caches.stream_change_cache import StreamChangeCache
2830
2931 logger = logging.getLogger(__name__)
3032
3133
32 # The ABCMeta metaclass ensures that it cannot be instantiated without
33 # the abstract methods being implemented.
34 class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
34 class AccountDataWorkerStore(SQLBaseStore):
3535 """This is an abstract base class where subclasses must implement
3636 `get_max_account_data_stream_id` which can be called in the initializer.
3737 """
3838
3939 def __init__(self, database: DatabasePool, db_conn, hs):
40 self._instance_name = hs.get_instance_name()
41
42 if isinstance(database.engine, PostgresEngine):
43 self._can_write_to_account_data = (
44 self._instance_name in hs.config.worker.writers.account_data
45 )
46
47 self._account_data_id_gen = MultiWriterIdGenerator(
48 db_conn=db_conn,
49 db=database,
50 stream_name="account_data",
51 instance_name=self._instance_name,
52 tables=[
53 ("room_account_data", "instance_name", "stream_id"),
54 ("room_tags_revisions", "instance_name", "stream_id"),
55 ("account_data", "instance_name", "stream_id"),
56 ],
57 sequence_name="account_data_sequence",
58 writers=hs.config.worker.writers.account_data,
59 )
60 else:
61 self._can_write_to_account_data = True
62
63 # We shouldn't be running in worker mode with SQLite, but it's useful
64 # to support it for unit tests.
65 #
66 # If this process is the writer then we need to use
67 # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
68 # updated over replication. (Multiple writers are not supported for
69 # SQLite).
70 if hs.get_instance_name() in hs.config.worker.writers.account_data:
71 self._account_data_id_gen = StreamIdGenerator(
72 db_conn,
73 "room_account_data",
74 "stream_id",
75 extra_tables=[("room_tags_revisions", "stream_id")],
76 )
77 else:
78 self._account_data_id_gen = SlavedIdTracker(
79 db_conn,
80 "room_account_data",
81 "stream_id",
82 extra_tables=[("room_tags_revisions", "stream_id")],
83 )
84
4085 account_max = self.get_max_account_data_stream_id()
4186 self._account_data_stream_cache = StreamChangeCache(
4287 "AccountDataAndTagsChangeCache", account_max
4489
4590 super().__init__(database, db_conn, hs)
4691
47 @abc.abstractmethod
48 def get_max_account_data_stream_id(self):
92 def get_max_account_data_stream_id(self) -> int:
4993 """Get the current max stream ID for account data stream
5094
5195 Returns:
5296 int
5397 """
54 raise NotImplementedError()
98 return self._account_data_id_gen.get_current_token()
5599
56100 @cached()
57101 async def get_account_data_for_user(
286330 "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
287331 )
288332
289 @cached(num_args=2, cache_context=True, max_entries=5000)
290 async def is_ignored_by(
291 self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
292 ) -> bool:
293 ignored_account_data = await self.get_global_account_data_by_type_for_user(
294 AccountDataTypes.IGNORED_USER_LIST,
295 ignorer_user_id,
296 on_invalidate=cache_context.invalidate,
297 )
298 if not ignored_account_data:
299 return False
300
301 try:
302 return ignored_user_id in ignored_account_data.get("ignored_users", {})
303 except TypeError:
304 # The type of the ignored_users field is invalid.
305 return False
306
307
308 class AccountDataStore(AccountDataWorkerStore):
309 def __init__(self, database: DatabasePool, db_conn, hs):
310 self._account_data_id_gen = StreamIdGenerator(
311 db_conn,
312 "account_data_max_stream_id",
313 "stream_id",
314 extra_tables=[
315 ("room_account_data", "stream_id"),
316 ("room_tags_revisions", "stream_id"),
317 ],
318 )
319
320 super().__init__(database, db_conn, hs)
321
322 def get_max_account_data_stream_id(self) -> int:
323 """Get the current max stream id for the private user data stream
324
325 Returns:
326 The maximum stream ID.
327 """
328 return self._account_data_id_gen.get_current_token()
333 @cached(max_entries=5000, iterable=True)
334 async def ignored_by(self, user_id: str) -> Set[str]:
335 """
336 Get users which ignore the given user.
337
338 Params:
339 user_id: The user ID which might be ignored.
340
341 Return:
342 The user IDs which ignore the given user.
343 """
344 return set(
345 await self.db_pool.simple_select_onecol(
346 table="ignored_users",
347 keyvalues={"ignored_user_id": user_id},
348 retcol="ignorer_user_id",
349 desc="ignored_by",
350 )
351 )
352
353 def process_replication_rows(self, stream_name, instance_name, token, rows):
354 if stream_name == TagAccountDataStream.NAME:
355 self._account_data_id_gen.advance(instance_name, token)
356 for row in rows:
357 self.get_tags_for_user.invalidate((row.user_id,))
358 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
359 elif stream_name == AccountDataStream.NAME:
360 self._account_data_id_gen.advance(instance_name, token)
361 for row in rows:
362 if not row.room_id:
363 self.get_global_account_data_by_type_for_user.invalidate(
364 (row.data_type, row.user_id)
365 )
366 self.get_account_data_for_user.invalidate((row.user_id,))
367 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
368 self.get_account_data_for_room_and_type.invalidate(
369 (row.user_id, row.room_id, row.data_type)
370 )
371 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
372 return super().process_replication_rows(stream_name, instance_name, token, rows)
329373
330374 async def add_account_data_to_room(
331375 self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
341385 Returns:
342386 The maximum stream ID.
343387 """
388 assert self._can_write_to_account_data
389
344390 content_json = json_encoder.encode(content)
345391
346392 async with self._account_data_id_gen.get_next() as next_id:
359405 lock=False,
360406 )
361407
362 # it's theoretically possible for the above to succeed and the
363 # below to fail - in which case we might reuse a stream id on
364 # restart, and the above update might not get propagated. That
365 # doesn't sound any worse than the whole update getting lost,
366 # which is what would happen if we combined the two into one
367 # transaction.
368 await self._update_max_stream_id(next_id)
369
370408 self._account_data_stream_cache.entity_has_changed(user_id, next_id)
371409 self.get_account_data_for_user.invalidate((user_id,))
372410 self.get_account_data_for_room.invalidate((user_id, room_id))
389427 Returns:
390428 The maximum stream ID.
391429 """
392 content_json = json_encoder.encode(content)
430 assert self._can_write_to_account_data
393431
394432 async with self._account_data_id_gen.get_next() as next_id:
395 # no need to lock here as account_data has a unique constraint on
396 # (user_id, account_data_type) so simple_upsert will retry if
397 # there is a conflict.
398 await self.db_pool.simple_upsert(
399 desc="add_user_account_data",
400 table="account_data",
401 keyvalues={"user_id": user_id, "account_data_type": account_data_type},
402 values={"stream_id": next_id, "content": content_json},
403 lock=False,
404 )
405
406 # it's theoretically possible for the above to succeed and the
407 # below to fail - in which case we might reuse a stream id on
408 # restart, and the above update might not get propagated. That
409 # doesn't sound any worse than the whole update getting lost,
410 # which is what would happen if we combined the two into one
411 # transaction.
412 #
413 # Note: This is only here for backwards compat to allow admins to
414 # roll back to a previous Synapse version. Next time we update the
415 # database version we can remove this table.
416 await self._update_max_stream_id(next_id)
433 await self.db_pool.runInteraction(
434 "add_user_account_data",
435 self._add_account_data_for_user,
436 next_id,
437 user_id,
438 account_data_type,
439 content,
440 )
417441
418442 self._account_data_stream_cache.entity_has_changed(user_id, next_id)
419443 self.get_account_data_for_user.invalidate((user_id,))
423447
424448 return self._account_data_id_gen.get_current_token()
425449
426 async def _update_max_stream_id(self, next_id: int) -> None:
427 """Update the max stream_id
428
429 Args:
430 next_id: The the revision to advance to.
431 """
432
433 # Note: This is only here for backwards compat to allow admins to
434 # roll back to a previous Synapse version. Next time we update the
435 # database version we can remove this table.
436
437 def _update(txn):
438 update_max_id_sql = (
439 "UPDATE account_data_max_stream_id"
440 " SET stream_id = ?"
441 " WHERE stream_id < ?"
442 )
443 txn.execute(update_max_id_sql, (next_id, next_id))
444
445 await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
450 def _add_account_data_for_user(
451 self,
452 txn,
453 next_id: int,
454 user_id: str,
455 account_data_type: str,
456 content: JsonDict,
457 ) -> None:
458 content_json = json_encoder.encode(content)
459
460 # no need to lock here as account_data has a unique constraint on
461 # (user_id, account_data_type) so simple_upsert will retry if
462 # there is a conflict.
463 self.db_pool.simple_upsert_txn(
464 txn,
465 table="account_data",
466 keyvalues={"user_id": user_id, "account_data_type": account_data_type},
467 values={"stream_id": next_id, "content": content_json},
468 lock=False,
469 )
470
471 # Ignored users get denormalized into a separate table as an optimisation.
472 if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
473 return
474
475 # Insert / delete to sync the list of ignored users.
476 previously_ignored_users = set(
477 self.db_pool.simple_select_onecol_txn(
478 txn,
479 table="ignored_users",
480 keyvalues={"ignorer_user_id": user_id},
481 retcol="ignored_user_id",
482 )
483 )
484
485 # If the data is invalid, no one is ignored.
486 ignored_users_content = content.get("ignored_users", {})
487 if isinstance(ignored_users_content, dict):
488 currently_ignored_users = set(ignored_users_content)
489 else:
490 currently_ignored_users = set()
491
492 # Delete entries which are no longer ignored.
493 self.db_pool.simple_delete_many_txn(
494 txn,
495 table="ignored_users",
496 column="ignored_user_id",
497 iterable=previously_ignored_users - currently_ignored_users,
498 keyvalues={"ignorer_user_id": user_id},
499 )
500
501 # Add entries which are newly ignored.
502 self.db_pool.simple_insert_many_txn(
503 txn,
504 table="ignored_users",
505 values=[
506 {"ignorer_user_id": user_id, "ignored_user_id": u}
507 for u in currently_ignored_users - previously_ignored_users
508 ],
509 )
510
511 # Invalidate the cache for any ignored users which were added or removed.
512 for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
513 self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
514
515
516 class AccountDataStore(AccountDataWorkerStore):
517 pass
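
An illustrative aside (not part of the upstream diff): the ignored_users synchronisation above reduces to plain set arithmetic - delete rows that are no longer ignored, insert rows that are newly ignored, and invalidate caches for anyone whose status changed in either direction. A minimal standalone sketch:

from typing import Set, Tuple


def diff_ignored_users(
    previously_ignored: Set[str], currently_ignored: Set[str]
) -> Tuple[Set[str], Set[str], Set[str]]:
    to_delete = previously_ignored - currently_ignored    # no longer ignored
    to_insert = currently_ignored - previously_ignored    # newly ignored
    changed = previously_ignored ^ currently_ignored      # invalidate caches for these
    return to_delete, to_insert, changed


# Example: previously {"@a:x", "@b:x"}, now {"@b:x", "@c:x"}
#   -> delete {"@a:x"}, insert {"@c:x"}, invalidate {"@a:x", "@c:x"}
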
406406 "_prune_old_user_ips", _prune_old_user_ips_txn
407407 )
408408
409 async def get_last_client_ip_by_device(
410 self, user_id: str, device_id: Optional[str]
411 ) -> Dict[Tuple[str, str], dict]:
412 """For each device_id listed, give the user_ip it was last seen on.
413
414 The result might be slightly out of date as client IPs are inserted in batches.
415
416 Args:
417 user_id: The user to fetch devices for.
418 device_id: If None fetches all devices for the user
419
420 Returns:
421 A dictionary mapping a tuple of (user_id, device_id) to dicts, with
422 keys giving the column names from the devices table.
423 """
424
425 keyvalues = {"user_id": user_id}
426 if device_id is not None:
427 keyvalues["device_id"] = device_id
428
429 res = await self.db_pool.simple_select_list(
430 table="devices",
431 keyvalues=keyvalues,
432 retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
433 )
434
435 return {(d["user_id"], d["device_id"]): d for d in res}
436
409437
410438 class ClientIpStore(ClientIpWorkerStore):
411439 def __init__(self, database: DatabasePool, db_conn, hs):
469497 for entry in to_update.items():
470498 (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
471499
472 try:
473 self.db_pool.simple_upsert_txn(
500 self.db_pool.simple_upsert_txn(
501 txn,
502 table="user_ips",
503 keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
504 values={
505 "user_agent": user_agent,
506 "device_id": device_id,
507 "last_seen": last_seen,
508 },
509 lock=False,
510 )
511
512 # Technically an access token might not be associated with
513 # a device so we need to check.
514 if device_id:
515 # this is always an update rather than an upsert: the row should
516 # already exist, and if it doesn't, that may be because it has been
517 # deleted, and we don't want to re-create it.
518 self.db_pool.simple_update_txn(
474519 txn,
475 table="user_ips",
476 keyvalues={
477 "user_id": user_id,
478 "access_token": access_token,
520 table="devices",
521 keyvalues={"user_id": user_id, "device_id": device_id},
522 updatevalues={
523 "user_agent": user_agent,
524 "last_seen": last_seen,
479525 "ip": ip,
480526 },
481 values={
482 "user_agent": user_agent,
483 "device_id": device_id,
484 "last_seen": last_seen,
485 },
486 lock=False,
487527 )
488
489 # Technically an access token might not be associated with
490 # a device so we need to check.
491 if device_id:
492 # this is always an update rather than an upsert: the row should
493 # already exist, and if it doesn't, that may be because it has been
494 # deleted, and we don't want to re-create it.
495 self.db_pool.simple_update_txn(
496 txn,
497 table="devices",
498 keyvalues={"user_id": user_id, "device_id": device_id},
499 updatevalues={
500 "user_agent": user_agent,
501 "last_seen": last_seen,
502 "ip": ip,
503 },
504 )
505 except Exception as e:
506 # Failed to upsert, log and continue
507 logger.error("Failed to insert client IP %r: %r", entry, e)
508528
509529 async def get_last_client_ip_by_device(
510530 self, user_id: str, device_id: Optional[str]
519539 A dictionary mapping a tuple of (user_id, device_id) to dicts, with
520540 keys giving the column names from the devices table.
521541 """
522
523 keyvalues = {"user_id": user_id}
524 if device_id is not None:
525 keyvalues["device_id"] = device_id
526
527 res = await self.db_pool.simple_select_list(
528 table="devices",
529 keyvalues=keyvalues,
530 retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
531 )
532
533 ret = {(d["user_id"], d["device_id"]): d for d in res}
542 ret = await super().get_last_client_ip_by_device(user_id, device_id)
543
544 # Update what is retrieved from the database with data which is pending insertion.
534545 for key in self._batch_row_update:
535546 uid, access_token, ip = key
536547 if uid == user_id:
1616 from typing import List, Tuple
1717
1818 from synapse.logging.opentracing import log_kv, set_tag, trace
19 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
19 from synapse.replication.tcp.streams import ToDeviceStream
20 from synapse.storage._base import SQLBaseStore, db_to_json
2021 from synapse.storage.database import DatabasePool
22 from synapse.storage.engines import PostgresEngine
23 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
2124 from synapse.util import json_encoder
2225 from synapse.util.caches.expiringcache import ExpiringCache
26 from synapse.util.caches.stream_change_cache import StreamChangeCache
2327
2428 logger = logging.getLogger(__name__)
2529
2630
2731 class DeviceInboxWorkerStore(SQLBaseStore):
32 def __init__(self, database: DatabasePool, db_conn, hs):
33 super().__init__(database, db_conn, hs)
34
35 self._instance_name = hs.get_instance_name()
36
37 # Map of (user_id, device_id) to the last stream_id that has been
38 # deleted up to. This is so that we can no op deletions.
39 self._last_device_delete_cache = ExpiringCache(
40 cache_name="last_device_delete_cache",
41 clock=self._clock,
42 max_len=10000,
43 expiry_ms=30 * 60 * 1000,
44 )
45
46 if isinstance(database.engine, PostgresEngine):
47 self._can_write_to_device = (
48 self._instance_name in hs.config.worker.writers.to_device
49 )
50
51 self._device_inbox_id_gen = MultiWriterIdGenerator(
52 db_conn=db_conn,
53 db=database,
54 stream_name="to_device",
55 instance_name=self._instance_name,
56 tables=[("device_inbox", "instance_name", "stream_id")],
57 sequence_name="device_inbox_sequence",
58 writers=hs.config.worker.writers.to_device,
59 )
60 else:
61 self._can_write_to_device = True
62 self._device_inbox_id_gen = StreamIdGenerator(
63 db_conn, "device_inbox", "stream_id"
64 )
65
66 max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
67 device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
68 db_conn,
69 "device_inbox",
70 entity_column="user_id",
71 stream_column="stream_id",
72 max_value=max_device_inbox_id,
73 limit=1000,
74 )
75 self._device_inbox_stream_cache = StreamChangeCache(
76 "DeviceInboxStreamChangeCache",
77 min_device_inbox_id,
78 prefilled_cache=device_inbox_prefill,
79 )
80
81 # The federation outbox and the local device inbox use the same
82 # stream_id generator.
83 device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
84 db_conn,
85 "device_federation_outbox",
86 entity_column="destination",
87 stream_column="stream_id",
88 max_value=max_device_inbox_id,
89 limit=1000,
90 )
91 self._device_federation_outbox_stream_cache = StreamChangeCache(
92 "DeviceFederationOutboxStreamChangeCache",
93 min_device_outbox_id,
94 prefilled_cache=device_outbox_prefill,
95 )
96
97 def process_replication_rows(self, stream_name, instance_name, token, rows):
98 if stream_name == ToDeviceStream.NAME:
99 self._device_inbox_id_gen.advance(instance_name, token)
100 for row in rows:
101 if row.entity.startswith("@"):
102 self._device_inbox_stream_cache.entity_has_changed(
103 row.entity, token
104 )
105 else:
106 self._device_federation_outbox_stream_cache.entity_has_changed(
107 row.entity, token
108 )
109 return super().process_replication_rows(stream_name, instance_name, token, rows)
110
28111 def get_to_device_stream_token(self):
29112 return self._device_inbox_id_gen.get_current_token()
30113
277360 "get_all_new_device_messages", get_all_new_device_messages_txn
278361 )
279362
280
281 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
282 DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
283
284 def __init__(self, database: DatabasePool, db_conn, hs):
285 super().__init__(database, db_conn, hs)
286
287 self.db_pool.updates.register_background_index_update(
288 "device_inbox_stream_index",
289 index_name="device_inbox_stream_id_user_id",
290 table="device_inbox",
291 columns=["stream_id", "user_id"],
292 )
293
294 self.db_pool.updates.register_background_update_handler(
295 self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
296 )
297
298 async def _background_drop_index_device_inbox(self, progress, batch_size):
299 def reindex_txn(conn):
300 txn = conn.cursor()
301 txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
302 txn.close()
303
304 await self.db_pool.runWithConnection(reindex_txn)
305
306 await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
307
308 return 1
309
310
311 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
312 DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
313
314 def __init__(self, database: DatabasePool, db_conn, hs):
315 super().__init__(database, db_conn, hs)
316
317 # Map of (user_id, device_id) to the last stream_id that has been
318 # deleted up to. This is so that we can no op deletions.
319 self._last_device_delete_cache = ExpiringCache(
320 cache_name="last_device_delete_cache",
321 clock=self._clock,
322 max_len=10000,
323 expiry_ms=30 * 60 * 1000,
324 )
325
326363 @trace
327364 async def add_messages_to_device_inbox(
328365 self,
341378 The new stream_id.
342379 """
343380
381 assert self._can_write_to_device
382
344383 def add_messages_txn(txn, now_ms, stream_id):
345384 # Add the local messages directly to the local inbox.
346385 self._add_messages_to_local_device_inbox_txn(
350389 # Add the remote messages to the federation outbox.
351390 # We'll send them to a remote server when we next send a
352391 # federation transaction to that destination.
353 sql = (
354 "INSERT INTO device_federation_outbox"
355 " (destination, stream_id, queued_ts, messages_json)"
356 " VALUES (?,?,?,?)"
357 )
358 rows = []
359 for destination, edu in remote_messages_by_destination.items():
360 edu_json = json_encoder.encode(edu)
361 rows.append((destination, stream_id, now_ms, edu_json))
362 txn.executemany(sql, rows)
392 self.db_pool.simple_insert_many_txn(
393 txn,
394 table="device_federation_outbox",
395 values=[
396 {
397 "destination": destination,
398 "stream_id": stream_id,
399 "queued_ts": now_ms,
400 "messages_json": json_encoder.encode(edu),
401 "instance_name": self._instance_name,
402 }
403 for destination, edu in remote_messages_by_destination.items()
404 ],
405 )
363406
364407 async with self._device_inbox_id_gen.get_next() as stream_id:
365408 now_ms = self.clock.time_msec()
378421 async def add_messages_from_remote_to_device_inbox(
379422 self, origin: str, message_id: str, local_messages_by_user_then_device: dict
380423 ) -> int:
424 assert self._can_write_to_device
425
381426 def add_messages_txn(txn, now_ms, stream_id):
382427 # Check if we've already inserted a matching message_id for that
383428 # origin. This can happen if the origin doesn't receive our
426471 def _add_messages_to_local_device_inbox_txn(
427472 self, txn, stream_id, messages_by_user_then_device
428473 ):
474 assert self._can_write_to_device
475
429476 local_by_user_then_device = {}
430477 for user_id, messages_by_device in messages_by_user_then_device.items():
431478 messages_json_for_user = {}
432479 devices = list(messages_by_device.keys())
433480 if len(devices) == 1 and devices[0] == "*":
434481 # Handle wildcard device_ids.
435 sql = "SELECT device_id FROM devices WHERE user_id = ?"
436 txn.execute(sql, (user_id,))
482 devices = self.db_pool.simple_select_onecol_txn(
483 txn,
484 table="devices",
485 keyvalues={"user_id": user_id},
486 retcol="device_id",
487 )
488
437489 message_json = json_encoder.encode(messages_by_device["*"])
438 for row in txn:
490 for device_id in devices:
439491 # Add the message for all devices for this user on this
440492 # server.
441 device = row[0]
442 messages_json_for_user[device] = message_json
493 messages_json_for_user[device_id] = message_json
443494 else:
444495 if not devices:
445496 continue
446497
447 clause, args = make_in_list_sql_clause(
448 txn.database_engine, "device_id", devices
498 rows = self.db_pool.simple_select_many_txn(
499 txn,
500 table="devices",
501 keyvalues={"user_id": user_id},
502 column="device_id",
503 iterable=devices,
504 retcols=("device_id",),
449505 )
450 sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
451
452 # TODO: Maybe this needs to be done in batches if there are
453 # too many local devices for a given user.
454 txn.execute(sql, [user_id] + list(args))
455 for row in txn:
506
507 for row in rows:
456508 # Only insert into the local inbox if the device exists on
457509 # this server
458 device = row[0]
459 message_json = json_encoder.encode(messages_by_device[device])
460 messages_json_for_user[device] = message_json
510 device_id = row["device_id"]
511 message_json = json_encoder.encode(messages_by_device[device_id])
512 messages_json_for_user[device_id] = message_json
461513
462514 if messages_json_for_user:
463515 local_by_user_then_device[user_id] = messages_json_for_user
465517 if not local_by_user_then_device:
466518 return
467519
468 sql = (
469 "INSERT INTO device_inbox"
470 " (user_id, device_id, stream_id, message_json)"
471 " VALUES (?,?,?,?)"
472 )
473 rows = []
474 for user_id, messages_by_device in local_by_user_then_device.items():
475 for device_id, message_json in messages_by_device.items():
476 rows.append((user_id, device_id, stream_id, message_json))
477
478 txn.executemany(sql, rows)
520 self.db_pool.simple_insert_many_txn(
521 txn,
522 table="device_inbox",
523 values=[
524 {
525 "user_id": user_id,
526 "device_id": device_id,
527 "stream_id": stream_id,
528 "message_json": message_json,
529 "instance_name": self._instance_name,
530 }
531 for user_id, messages_by_device in local_by_user_then_device.items()
532 for device_id, message_json in messages_by_device.items()
533 ],
534 )
535
536
537 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
538 DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
539
540 def __init__(self, database: DatabasePool, db_conn, hs):
541 super().__init__(database, db_conn, hs)
542
543 self.db_pool.updates.register_background_index_update(
544 "device_inbox_stream_index",
545 index_name="device_inbox_stream_id_user_id",
546 table="device_inbox",
547 columns=["stream_id", "user_id"],
548 )
549
550 self.db_pool.updates.register_background_update_handler(
551 self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
552 )
553
554 async def _background_drop_index_device_inbox(self, progress, batch_size):
555 def reindex_txn(conn):
556 txn = conn.cursor()
557 txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
558 txn.close()
559
560 await self.db_pool.runWithConnection(reindex_txn)
561
562 await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
563
564 return 1
565
566
567 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
568 pass
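
An illustrative aside (not part of the upstream diff): the replication handler above keys its two stream-change caches on the row's `entity` field - Matrix user IDs start with "@", anything else is treated as a remote server name whose federation outbox changed. A minimal sketch of that dispatch, with made-up cache objects:

def route_to_device_row(entity: str, token: int, inbox_cache, outbox_cache) -> None:
    if entity.startswith("@"):
        # Local user: their device inbox changed at this stream position.
        inbox_cache.entity_has_changed(entity, token)
    else:
        # Remote destination: the federation outbox for that server changed.
        outbox_cache.entity_has_changed(entity, token)


# route_to_device_row("@alice:example.com", 42, inbox_cache, outbox_cache)
# route_to_device_row("matrix.example.org", 42, inbox_cache, outbox_cache)
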
2424 from synapse.logging.opentracing import log_kv, set_tag, trace
2525 from synapse.storage._base import SQLBaseStore, db_to_json
2626 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
27 from synapse.storage.engines import PostgresEngine
2728 from synapse.storage.types import Cursor
2829 from synapse.types import JsonDict
2930 from synapse.util import json_encoder
512513
513514 for user_chunk in batch_iter(user_ids, 100):
514515 clause, params = make_in_list_sql_clause(
515 txn.database_engine, "k.user_id", user_chunk
516 )
517 sql = (
518 """
519 SELECT k.user_id, k.keytype, k.keydata, k.stream_id
520 FROM e2e_cross_signing_keys k
521 INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
522 FROM e2e_cross_signing_keys
523 GROUP BY user_id, keytype) s
524 USING (user_id, stream_id, keytype)
525 WHERE
526 """
527 + clause
528 )
516 txn.database_engine, "user_id", user_chunk
517 )
518
519 # Fetch the latest key for each type per user.
520 if isinstance(self.database_engine, PostgresEngine):
521 # The `DISTINCT ON` clause will pick the *first* row it
522 # encounters, so ordering by stream ID desc will ensure we get
523 # the latest key.
524 sql = """
525 SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
526 FROM e2e_cross_signing_keys
527 WHERE %(clause)s
528 ORDER BY user_id, keytype, stream_id DESC
529 """ % {
530 "clause": clause
531 }
532 else:
533 # SQLite has special handling for bare columns when using
534 # MIN/MAX with a `GROUP BY` clause where it picks the value from
535 # a row that matches the MIN/MAX.
536 sql = """
537 SELECT user_id, keytype, keydata, MAX(stream_id)
538 FROM e2e_cross_signing_keys
539 WHERE %(clause)s
540 GROUP BY user_id, keytype
541 """ % {
542 "clause": clause
543 }
529544
530545 txn.execute(sql, params)
531546 rows = self.db_pool.cursor_to_dict(txn)
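
An illustrative aside (not part of the upstream diff): a self-contained demo of the documented SQLite behaviour the SQLite branch above relies on - when a query uses MAX() with GROUP BY, the "bare" columns are taken from the row that supplied the maximum, so `keydata` below is the payload of the newest stream_id per (user_id, keytype). On PostgreSQL the equivalent is the DISTINCT ON ... ORDER BY ... DESC form shown above. Table and values are made up for the example.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE keys (user_id TEXT, keytype TEXT, keydata TEXT, stream_id INTEGER);
    INSERT INTO keys VALUES ('@a:x', 'master', 'old', 1);
    INSERT INTO keys VALUES ('@a:x', 'master', 'new', 2);
    """
)
row = conn.execute(
    "SELECT user_id, keytype, keydata, MAX(stream_id) FROM keys GROUP BY user_id, keytype"
).fetchone()
assert row == ("@a:x", "master", "new", 2)  # keydata comes from the row with MAX(stream_id)
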
705720 def get_device_stream_token(self) -> int:
706721 """Get the current stream id from the _device_list_id_gen"""
707722 ...
708
709
710 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
711 async def set_e2e_device_keys(
712 self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
713 ) -> bool:
714 """Stores device keys for a device. Returns whether there was a change
715 or the keys were already in the database.
716 """
717
718 def _set_e2e_device_keys_txn(txn):
719 set_tag("user_id", user_id)
720 set_tag("device_id", device_id)
721 set_tag("time_now", time_now)
722 set_tag("device_keys", device_keys)
723
724 old_key_json = self.db_pool.simple_select_one_onecol_txn(
725 txn,
726 table="e2e_device_keys_json",
727 keyvalues={"user_id": user_id, "device_id": device_id},
728 retcol="key_json",
729 allow_none=True,
730 )
731
732 # In py3 we need old_key_json to match new_key_json type. The DB
733 # returns unicode while encode_canonical_json returns bytes.
734 new_key_json = encode_canonical_json(device_keys).decode("utf-8")
735
736 if old_key_json == new_key_json:
737 log_kv({"Message": "Device key already stored."})
738 return False
739
740 self.db_pool.simple_upsert_txn(
741 txn,
742 table="e2e_device_keys_json",
743 keyvalues={"user_id": user_id, "device_id": device_id},
744 values={"ts_added_ms": time_now, "key_json": new_key_json},
745 )
746 log_kv({"message": "Device keys stored."})
747 return True
748
749 return await self.db_pool.runInteraction(
750 "set_e2e_device_keys", _set_e2e_device_keys_txn
751 )
752723
753724 async def claim_e2e_one_time_keys(
754725 self, query_list: Iterable[Tuple[str, str, str]]
839810 "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
840811 )
841812
813
814 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
815 async def set_e2e_device_keys(
816 self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
817 ) -> bool:
818 """Stores device keys for a device. Returns whether there was a change
819 or the keys were already in the database.
820 """
821
822 def _set_e2e_device_keys_txn(txn):
823 set_tag("user_id", user_id)
824 set_tag("device_id", device_id)
825 set_tag("time_now", time_now)
826 set_tag("device_keys", device_keys)
827
828 old_key_json = self.db_pool.simple_select_one_onecol_txn(
829 txn,
830 table="e2e_device_keys_json",
831 keyvalues={"user_id": user_id, "device_id": device_id},
832 retcol="key_json",
833 allow_none=True,
834 )
835
836 # In py3 we need old_key_json to match new_key_json type. The DB
837 # returns unicode while encode_canonical_json returns bytes.
838 new_key_json = encode_canonical_json(device_keys).decode("utf-8")
839
840 if old_key_json == new_key_json:
841 log_kv({"Message": "Device key already stored."})
842 return False
843
844 self.db_pool.simple_upsert_txn(
845 txn,
846 table="e2e_device_keys_json",
847 keyvalues={"user_id": user_id, "device_id": device_id},
848 values={"ts_added_ms": time_now, "key_json": new_key_json},
849 )
850 log_kv({"message": "Device keys stored."})
851 return True
852
853 return await self.db_pool.runInteraction(
854 "set_e2e_device_keys", _set_e2e_device_keys_txn
855 )
856
842857 async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
843858 def delete_e2e_keys_by_device_txn(txn):
844859 log_kv(
2323 from synapse.storage.database import DatabasePool, LoggingTransaction
2424 from synapse.storage.databases.main.events_worker import EventsWorkerStore
2525 from synapse.storage.databases.main.signatures import SignatureWorkerStore
26 from synapse.storage.engines import PostgresEngine
27 from synapse.storage.types import Cursor
2628 from synapse.types import Collection
2729 from synapse.util.caches.descriptors import cached
2830 from synapse.util.caches.lrucache import LruCache
2931 from synapse.util.iterutils import batch_iter
3032
3133 logger = logging.getLogger(__name__)
34
35
36 class _NoChainCoverIndex(Exception):
37 def __init__(self, room_id: str):
38 super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
3239
3340
3441 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
150157 The set of the difference in auth chains.
151158 """
152159
160 # Check if we have indexed the room so we can use the chain cover
161 # algorithm.
162 room = await self.get_room(room_id)
163 if room["has_auth_chain_index"]:
164 try:
165 return await self.db_pool.runInteraction(
166 "get_auth_chain_difference_chains",
167 self._get_auth_chain_difference_using_cover_index_txn,
168 room_id,
169 state_sets,
170 )
171 except _NoChainCoverIndex:
172 # For whatever reason we don't actually have a chain cover index
173 # for the events in question, so we fall back to the old method.
174 pass
175
153176 return await self.db_pool.runInteraction(
154177 "get_auth_chain_difference",
155178 self._get_auth_chain_difference_txn,
156179 state_sets,
157180 )
158181
182 def _get_auth_chain_difference_using_cover_index_txn(
183 self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
184 ) -> Set[str]:
185 """Calculates the auth chain difference using the chain index.
186
187 See docs/auth_chain_difference_algorithm.md for details
188 """
189
190 # First we look up the chain ID/sequence numbers for all the events, and
191 # work out the chain/sequence numbers reachable from each state set.
192
193 initial_events = set(state_sets[0]).union(*state_sets[1:])
194
195 # Map from event_id -> (chain ID, seq no)
196 chain_info = {} # type: Dict[str, Tuple[int, int]]
197
198 # Map from chain ID -> seq no -> event Id
199 chain_to_event = {} # type: Dict[int, Dict[int, str]]
200
201 # All the chains that we've found that are reachable from the state
202 # sets.
203 seen_chains = set() # type: Set[int]
204
205 sql = """
206 SELECT event_id, chain_id, sequence_number
207 FROM event_auth_chains
208 WHERE %s
209 """
210 for batch in batch_iter(initial_events, 1000):
211 clause, args = make_in_list_sql_clause(
212 txn.database_engine, "event_id", batch
213 )
214 txn.execute(sql % (clause,), args)
215
216 for event_id, chain_id, sequence_number in txn:
217 chain_info[event_id] = (chain_id, sequence_number)
218 seen_chains.add(chain_id)
219 chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
220
221 # Check that we actually have a chain ID for all the events.
222 events_missing_chain_info = initial_events.difference(chain_info)
223 if events_missing_chain_info:
224 # This can happen due to e.g. downgrade/upgrade of the server. We
225 # raise an exception and fall back to the previous algorithm.
226 logger.info(
227 "Unexpectedly found that events don't have chain IDs in room %s: %s",
228 room_id,
229 events_missing_chain_info,
230 )
231 raise _NoChainCoverIndex(room_id)
232
233 # Corresponds to `state_sets`, except as a map from chain ID to max
234 # sequence number reachable from the state set.
235 set_to_chain = [] # type: List[Dict[int, int]]
236 for state_set in state_sets:
237 chains = {} # type: Dict[int, int]
238 set_to_chain.append(chains)
239
240 for event_id in state_set:
241 chain_id, seq_no = chain_info[event_id]
242
243 chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
244
245 # Now we look up all links for the chains we have, adding chains to
246 # set_to_chain that are reachable from each set.
247 sql = """
248 SELECT
249 origin_chain_id, origin_sequence_number,
250 target_chain_id, target_sequence_number
251 FROM event_auth_chain_links
252 WHERE %s
253 """
254
255 # (We need to take a copy of `seen_chains` as we want to mutate it in
256 # the loop)
257 for batch in batch_iter(set(seen_chains), 1000):
258 clause, args = make_in_list_sql_clause(
259 txn.database_engine, "origin_chain_id", batch
260 )
261 txn.execute(sql % (clause,), args)
262
263 for (
264 origin_chain_id,
265 origin_sequence_number,
266 target_chain_id,
267 target_sequence_number,
268 ) in txn:
269 for chains in set_to_chain:
270 # chains are only reachable if the origin sequence number of
271 # the link is no greater than the max sequence number in the
272 # origin chain.
273 if origin_sequence_number <= chains.get(origin_chain_id, 0):
274 chains[target_chain_id] = max(
275 target_sequence_number, chains.get(target_chain_id, 0),
276 )
277
278 seen_chains.add(target_chain_id)
279
280 # Now for each chain we figure out the maximum sequence number reachable
281 # from *any* state set and the minimum sequence number reachable from
282 # *all* state sets. Events in that range are in the auth chain
283 # difference.
284 result = set()
285
286 # Mapping from chain ID to the range of sequence numbers that should be
287 # pulled from the database.
288 chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
289
290 for chain_id in seen_chains:
291 min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
292 max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
293
294 if min_seq_no < max_seq_no:
295 # We have a non-empty gap, so try to fill it from the events that
296 # we already have; otherwise record the gap to pull out
297 # from the DB.
298 for seq_no in range(min_seq_no + 1, max_seq_no + 1):
299 event_id = chain_to_event.get(chain_id, {}).get(seq_no)
300 if event_id:
301 result.add(event_id)
302 else:
303 chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
304 break
305
306 if not chain_to_gap:
307 # If there are no gaps to fetch, we're done!
308 return result
309
310 if isinstance(self.database_engine, PostgresEngine):
311 # We can use `execute_values` to efficiently fetch the gaps when
312 # using postgres.
313 sql = """
314 SELECT event_id
315 FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
316 WHERE
317 c.chain_id = l.chain_id
318 AND min_seq < sequence_number AND sequence_number <= max_seq
319 """
320
321 args = [
322 (chain_id, min_no, max_no)
323 for chain_id, (min_no, max_no) in chain_to_gap.items()
324 ]
325
326 rows = txn.execute_values(sql, args)
327 result.update(r for r, in rows)
328 else:
329 # For SQLite we just fall back to doing a noddy for loop.
330 sql = """
331 SELECT event_id FROM event_auth_chains
332 WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
333 """
334 for chain_id, (min_no, max_no) in chain_to_gap.items():
335 txn.execute(sql, (chain_id, min_no, max_no))
336 result.update(r for r, in txn)
337
338 return result
339
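
An illustrative aside (not part of the upstream diff): the final step of the cover-index calculation above reduces, per chain, to comparing the maximum sequence number reachable from any state set with the minimum reachable from every state set; events in the half-open range (min, max] are in the auth chain difference. A toy sketch over in-memory maps (chain id -> max reachable sequence number per state set):

from typing import Dict, List, Set, Tuple


def chain_difference_ranges(
    set_to_chain: List[Dict[int, int]], seen_chains: Set[int]
) -> Dict[int, Tuple[int, int]]:
    ranges = {}
    for chain_id in seen_chains:
        min_seq = min(chains.get(chain_id, 0) for chains in set_to_chain)
        max_seq = max(chains.get(chain_id, 0) for chains in set_to_chain)
        if min_seq < max_seq:
            # Events with min_seq < sequence_number <= max_seq are in the difference.
            ranges[chain_id] = (min_seq, max_seq)
    return ranges


# Example: chain 7 reachable up to seq 5 from one state set and seq 2 from the other:
#   chain_difference_ranges([{7: 5}, {7: 2}], {7}) -> {7: (2, 5)}
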
159340 def _get_auth_chain_difference_txn(
160341 self, txn, state_sets: List[Set[str]]
161342 ) -> Set[str]:
343 """Calculates the auth chain difference using a breadth first search.
344
345 This is used when we don't have a cover index for the room.
346 """
162347
163348 # Algorithm Description
164349 # ~~~~~~~~~~~~~~~~~~~~~
834834 (rotate_to_stream_ordering,),
835835 )
836836
837
838 class EventPushActionsStore(EventPushActionsWorkerStore):
839 EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
840
841 def __init__(self, database: DatabasePool, db_conn, hs):
842 super().__init__(database, db_conn, hs)
843
844 self.db_pool.updates.register_background_index_update(
845 self.EPA_HIGHLIGHT_INDEX,
846 index_name="event_push_actions_u_highlight",
847 table="event_push_actions",
848 columns=["user_id", "stream_ordering"],
849 )
850
851 self.db_pool.updates.register_background_index_update(
852 "event_push_actions_highlights_index",
853 index_name="event_push_actions_highlights_index",
854 table="event_push_actions",
855 columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
856 where_clause="highlight=1",
857 )
858
859 async def get_push_actions_for_user(
860 self, user_id, before=None, limit=50, only_highlight=False
861 ):
862 def f(txn):
863 before_clause = ""
864 if before:
865 before_clause = "AND epa.stream_ordering < ?"
866 args = [user_id, before, limit]
867 else:
868 args = [user_id, limit]
869
870 if only_highlight:
871 if len(before_clause) > 0:
872 before_clause += " "
873 before_clause += "AND epa.highlight = 1"
874
875 # NB. This assumes event_ids are globally unique since
876 # it makes the query easier to index
877 sql = (
878 "SELECT epa.event_id, epa.room_id,"
879 " epa.stream_ordering, epa.topological_ordering,"
880 " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
881 " FROM event_push_actions epa, events e"
882 " WHERE epa.event_id = e.event_id"
883 " AND epa.user_id = ? %s"
884 " AND epa.notif = 1"
885 " ORDER BY epa.stream_ordering DESC"
886 " LIMIT ?" % (before_clause,)
887 )
888 txn.execute(sql, args)
889 return self.db_pool.cursor_to_dict(txn)
890
891 push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
892 for pa in push_actions:
893 pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
894 return push_actions
895
896837 def _remove_old_push_actions_before_txn(
897838 self, txn, room_id, user_id, stream_ordering
898839 ):
940881 )
941882
942883
884 class EventPushActionsStore(EventPushActionsWorkerStore):
885 EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
886
887 def __init__(self, database: DatabasePool, db_conn, hs):
888 super().__init__(database, db_conn, hs)
889
890 self.db_pool.updates.register_background_index_update(
891 self.EPA_HIGHLIGHT_INDEX,
892 index_name="event_push_actions_u_highlight",
893 table="event_push_actions",
894 columns=["user_id", "stream_ordering"],
895 )
896
897 self.db_pool.updates.register_background_index_update(
898 "event_push_actions_highlights_index",
899 index_name="event_push_actions_highlights_index",
900 table="event_push_actions",
901 columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
902 where_clause="highlight=1",
903 )
904
905 async def get_push_actions_for_user(
906 self, user_id, before=None, limit=50, only_highlight=False
907 ):
908 def f(txn):
909 before_clause = ""
910 if before:
911 before_clause = "AND epa.stream_ordering < ?"
912 args = [user_id, before, limit]
913 else:
914 args = [user_id, limit]
915
916 if only_highlight:
917 if len(before_clause) > 0:
918 before_clause += " "
919 before_clause += "AND epa.highlight = 1"
920
921 # NB. This assumes event_ids are globally unique since
922 # it makes the query easier to index
923 sql = (
924 "SELECT epa.event_id, epa.room_id,"
925 " epa.stream_ordering, epa.topological_ordering,"
926 " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
927 " FROM event_push_actions epa, events e"
928 " WHERE epa.event_id = e.event_id"
929 " AND epa.user_id = ? %s"
930 " AND epa.notif = 1"
931 " ORDER BY epa.stream_ordering DESC"
932 " LIMIT ?" % (before_clause,)
933 )
934 txn.execute(sql, args)
935 return self.db_pool.cursor_to_dict(txn)
936
937 push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
938 for pa in push_actions:
939 pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
940 return push_actions
941
942
943943 def _action_has_highlight(actions):
944944 for action in actions:
945945 try:
1616 import itertools
1717 import logging
1818 from collections import OrderedDict, namedtuple
19 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
19 from typing import (
20 TYPE_CHECKING,
21 Any,
22 Dict,
23 Generator,
24 Iterable,
25 List,
26 Optional,
27 Set,
28 Tuple,
29 )
2030
2131 import attr
2232 from prometheus_client import Counter
3444 from synapse.storage.util.id_generators import MultiWriterIdGenerator
3545 from synapse.types import StateMap, get_domain_from_id
3646 from synapse.util import json_encoder
37 from synapse.util.iterutils import batch_iter
47 from synapse.util.iterutils import batch_iter, sorted_topologically
3848
3949 if TYPE_CHECKING:
4050 from synapse.server import HomeServer
365375 # Insert into event_to_state_groups.
366376 self._store_event_state_mappings_txn(txn, events_and_contexts)
367377
378 self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
379
380 # _store_rejected_events_txn filters out any events which were
381 # rejected, and returns the filtered list.
382 events_and_contexts = self._store_rejected_events_txn(
383 txn, events_and_contexts=events_and_contexts
384 )
385
386 # From this point onwards the events are only ones that weren't
387 # rejected.
388
389 self._update_metadata_tables_txn(
390 txn,
391 events_and_contexts=events_and_contexts,
392 all_events_and_contexts=all_events_and_contexts,
393 backfilled=backfilled,
394 )
395
396 # We call this last as it assumes we've inserted the events into
397 # room_memberships, where applicable.
398 self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
399
400 def _persist_event_auth_chain_txn(
401 self, txn: LoggingTransaction, events: List[EventBase],
402 ) -> None:
403
404 # We only care about state events, so return early if there are no state events.
405 if not any(e.is_state() for e in events):
406 return
407
368408 # We want to store event_auth mappings for rejected events, as they're
369409 # used in state res v2.
370410 # This is only necessary if the rejected event appears in an accepted
380420 "room_id": event.room_id,
381421 "auth_id": auth_id,
382422 }
383 for event, _ in events_and_contexts
423 for event in events
384424 for auth_id in event.auth_event_ids()
385425 if event.is_state()
386426 ],
387427 )
388428
389 # _store_rejected_events_txn filters out any events which were
390 # rejected, and returns the filtered list.
391 events_and_contexts = self._store_rejected_events_txn(
392 txn, events_and_contexts=events_and_contexts
393 )
394
395 # From this point onwards the events are only ones that weren't
396 # rejected.
397
398 self._update_metadata_tables_txn(
399 txn,
400 events_and_contexts=events_and_contexts,
401 all_events_and_contexts=all_events_and_contexts,
402 backfilled=backfilled,
403 )
404
405 # We call this last as it assumes we've inserted the events into
406 # room_memberships, where applicable.
407 self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
429 # We now calculate chain ID/sequence numbers for any state events we're
430 # persisting. We ignore out of band memberships as we're not in the room
431 # and won't have their auth chain (we'll fix it up later if we join the
432 # room).
433 #
434 # See: docs/auth_chain_difference_algorithm.md
435
436 # We ignore legacy rooms that we aren't filling the chain cover index
437 # for.
438 rows = self.db_pool.simple_select_many_txn(
439 txn,
440 table="rooms",
441 column="room_id",
442 iterable={event.room_id for event in events if event.is_state()},
443 keyvalues={},
444 retcols=("room_id", "has_auth_chain_index"),
445 )
446 rooms_using_chain_index = {
447 row["room_id"] for row in rows if row["has_auth_chain_index"]
448 }
449
450 state_events = {
451 event.event_id: event
452 for event in events
453 if event.is_state() and event.room_id in rooms_using_chain_index
454 }
455
456 if not state_events:
457 return
458
459 # We need to know the type/state_key and auth events of the events we're
460 # calculating chain IDs for. We don't rely on having the full Event
461 # instances as we'll potentially be pulling more events from the DB and
462 # we don't need the overhead of fetching/parsing the full event JSON.
463 event_to_types = {
464 e.event_id: (e.type, e.state_key) for e in state_events.values()
465 }
466 event_to_auth_chain = {
467 e.event_id: e.auth_event_ids() for e in state_events.values()
468 }
469 event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
470
471 self._add_chain_cover_index(
472 txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
473 )
474
475 @staticmethod
476 def _add_chain_cover_index(
477 txn,
478 db_pool: DatabasePool,
479 event_to_room_id: Dict[str, str],
480 event_to_types: Dict[str, Tuple[str, str]],
481 event_to_auth_chain: Dict[str, List[str]],
482 ) -> None:
483 """Calculate the chain cover index for the given events.
484
485 Args:
486 event_to_room_id: Event ID to the room ID of the event
487 event_to_types: Event ID to type and state_key of the event
488 event_to_auth_chain: Event ID to list of auth event IDs of the
489 event (events with no auth events can be excluded).
490 """
491
492 # Map from event ID to chain ID/sequence number.
493 chain_map = {} # type: Dict[str, Tuple[int, int]]
494
495 # Set of event IDs to calculate chain ID/seq numbers for.
496 events_to_calc_chain_id_for = set(event_to_room_id)
497
498 # We check if there are any events that need to be handled in the rooms
499 # we're looking at. These should just be out of band memberships, where
500 # we didn't have the auth chain when we first persisted them.
501 rows = db_pool.simple_select_many_txn(
502 txn,
503 table="event_auth_chain_to_calculate",
504 keyvalues={},
505 column="room_id",
506 iterable=set(event_to_room_id.values()),
507 retcols=("event_id", "type", "state_key"),
508 )
509 for row in rows:
510 event_id = row["event_id"]
511 event_type = row["type"]
512 state_key = row["state_key"]
513
514 # (We could pull out the auth events for all rows at once using
515 # simple_select_many, but this case happens rarely and almost always
516 # with a single row.)
517 auth_events = db_pool.simple_select_onecol_txn(
518 txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
519 )
520
521 events_to_calc_chain_id_for.add(event_id)
522 event_to_types[event_id] = (event_type, state_key)
523 event_to_auth_chain[event_id] = auth_events
524
525 # First we get the chain ID and sequence numbers for the events'
526 # auth events (that aren't also currently being persisted).
527 #
528 # Note that there is an edge case here where we might not have
529 # calculated chains and sequence numbers for events that were "out
530 # of band". We handle this case by fetching the necessary info and
531 # adding it to the set of events to calculate chain IDs for.
532
533 missing_auth_chains = {
534 a_id
535 for auth_events in event_to_auth_chain.values()
536 for a_id in auth_events
537 if a_id not in events_to_calc_chain_id_for
538 }
539
540 # We loop here in case we find an out of band membership and need to
541 # fetch its auth event info.
542 while missing_auth_chains:
543 sql = """
544 SELECT event_id, events.type, state_key, chain_id, sequence_number
545 FROM events
546 INNER JOIN state_events USING (event_id)
547 LEFT JOIN event_auth_chains USING (event_id)
548 WHERE
549 """
550 clause, args = make_in_list_sql_clause(
551 txn.database_engine, "event_id", missing_auth_chains,
552 )
553 txn.execute(sql + clause, args)
554
555 missing_auth_chains.clear()
556
557 for auth_id, event_type, state_key, chain_id, sequence_number in txn:
558 event_to_types[auth_id] = (event_type, state_key)
559
560 if chain_id is None:
561 # No chain ID, so the event was persisted out of band.
562 # We add it to the list of events to calculate auth chains for.
563
564 events_to_calc_chain_id_for.add(auth_id)
565
566 event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
567 txn,
568 "event_auth",
569 keyvalues={"event_id": auth_id},
570 retcol="auth_id",
571 )
572
573 missing_auth_chains.update(
574 e
575 for e in event_to_auth_chain[auth_id]
576 if e not in event_to_types
577 )
578 else:
579 chain_map[auth_id] = (chain_id, sequence_number)
580
581 # Now we check if there are any events for which we don't have the full
582 # auth chain; these should only be out of band memberships.
583 for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
584 for auth_id in event_to_auth_chain[event_id]:
585 if (
586 auth_id not in chain_map
587 and auth_id not in events_to_calc_chain_id_for
588 ):
589 events_to_calc_chain_id_for.discard(event_id)
590
591 # If this is an event we're trying to persist we add it to
592 # the list of events to calculate chain IDs for next time
593 # around. (Otherwise we will have already added it to the
594 # table).
595 room_id = event_to_room_id.get(event_id)
596 if room_id:
597 e_type, state_key = event_to_types[event_id]
598 db_pool.simple_insert_txn(
599 txn,
600 table="event_auth_chain_to_calculate",
601 values={
602 "event_id": event_id,
603 "room_id": room_id,
604 "type": e_type,
605 "state_key": state_key,
606 },
607 )
608
609 # We stop checking the event's auth events since we've
610 # discarded it.
611 break
612
613 if not events_to_calc_chain_id_for:
614 return
615
616 # We now calculate the chain IDs/sequence numbers for the events. We
617 # do this by looking at the chain ID and sequence number of any auth
618 # event with the same type/state_key and incrementing the sequence
619 # number by one. If there was no match or the chain ID/sequence
620 # number is already taken we generate a new chain.
621 #
622 # We need to do this in a topologically sorted order as we want to
623 # generate chain IDs/sequence numbers of an event's auth events
624 # before the event itself.
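        # (Illustrative example, not part of the original code: suppose an
        # m.room.member event for @alice has an auth event that is also an
        # m.room.member event for @alice, and that auth event sits at
        # (chain 7, sequence 3). We then propose (7, 4) for the new event;
        # if (7, 4) turns out to be taken, we instead start a brand new
        # chain at sequence number 1.)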
625 chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
626 new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
627 for event_id in sorted_topologically(
628 events_to_calc_chain_id_for, event_to_auth_chain
629 ):
630 existing_chain_id = None
631 for auth_id in event_to_auth_chain.get(event_id, []):
632 if event_to_types.get(event_id) == event_to_types.get(auth_id):
633 existing_chain_id = chain_map[auth_id]
634 break
635
636 new_chain_tuple = None
637 if existing_chain_id:
638 # We found a chain ID/sequence number candidate; check it's
639 # not already taken.
640 proposed_new_id = existing_chain_id[0]
641 proposed_new_seq = existing_chain_id[1] + 1
642 if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
643 already_allocated = db_pool.simple_select_one_onecol_txn(
644 txn,
645 table="event_auth_chains",
646 keyvalues={
647 "chain_id": proposed_new_id,
648 "sequence_number": proposed_new_seq,
649 },
650 retcol="event_id",
651 allow_none=True,
652 )
653 if already_allocated:
654 # Mark it as already allocated so we don't need to hit
655 # the DB again.
656 chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
657 else:
658 new_chain_tuple = (
659 proposed_new_id,
660 proposed_new_seq,
661 )
662
663 if not new_chain_tuple:
664 new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
665
666 chains_tuples_allocated.add(new_chain_tuple)
667
668 chain_map[event_id] = new_chain_tuple
669 new_chain_tuples[event_id] = new_chain_tuple
670
671 db_pool.simple_insert_many_txn(
672 txn,
673 table="event_auth_chains",
674 values=[
675 {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
676 for event_id, (c_id, seq) in new_chain_tuples.items()
677 ],
678 )
679
680 db_pool.simple_delete_many_txn(
681 txn,
682 table="event_auth_chain_to_calculate",
683 keyvalues={},
684 column="event_id",
685 iterable=new_chain_tuples,
686 )
687
688 # Now we need to calculate any new links between chains caused by
689 # the new events.
690 #
691 # Links are pairs of chain ID/sequence numbers such that for any
692 # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
693 # if and only if there is at least one link (CA, S1) -> (CB, S2)
694 # where SA >= S1 and S2 >= SB.
695 #
696 # We try and avoid adding redundant links to the table, e.g. if we
697 # have two links between two chains which both start/end at the same
698 # sequence number (or cross), then one can be safely dropped.
699 #
700 # To calculate new links we look at every new event and:
701 # 1. Fetch the chain ID/sequence numbers of its auth events,
702 # discarding any that are reachable by other auth events, or
703 # that have the same chain ID as the event.
704 # 2. For each retained auth event we:
705 # a. Add a link from the event's chain ID/sequence number to the
706 # auth event's chain ID/sequence number; and
707 # b. Add a link from the event to every chain reachable by the
708 # auth event.
709
710 # Step 1, fetch all existing links from all the chains we've seen
711 # referenced.
712 chain_links = _LinkMap()
713 rows = db_pool.simple_select_many_txn(
714 txn,
715 table="event_auth_chain_links",
716 column="origin_chain_id",
717 iterable={chain_id for chain_id, _ in chain_map.values()},
718 keyvalues={},
719 retcols=(
720 "origin_chain_id",
721 "origin_sequence_number",
722 "target_chain_id",
723 "target_sequence_number",
724 ),
725 )
726 for row in rows:
727 chain_links.add_link(
728 (row["origin_chain_id"], row["origin_sequence_number"]),
729 (row["target_chain_id"], row["target_sequence_number"]),
730 new=False,
731 )
732
733 # We do this in topological order to avoid adding redundant links.
734 for event_id in sorted_topologically(
735 events_to_calc_chain_id_for, event_to_auth_chain
736 ):
737 chain_id, sequence_number = chain_map[event_id]
738
739 # Filter out auth events that are reachable by other auth
740 # events. We do this by looking at every permutation of pairs of
741 # auth events (A, B) to check if B is reachable from A.
742 reduction = {
743 a_id
744 for a_id in event_to_auth_chain.get(event_id, [])
745 if chain_map[a_id][0] != chain_id
746 }
747 for start_auth_id, end_auth_id in itertools.permutations(
748 event_to_auth_chain.get(event_id, []), r=2,
749 ):
750 if chain_links.exists_path_from(
751 chain_map[start_auth_id], chain_map[end_auth_id]
752 ):
753 reduction.discard(end_auth_id)
754
755 # Step 2, figure out what the new links are from the reduced
756 # list of auth events.
757 for auth_id in reduction:
758 auth_chain_id, auth_sequence_number = chain_map[auth_id]
759
760 # Step 2a, add a link between the event and the auth event
761 chain_links.add_link(
762 (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
763 )
764
765 # Step 2b, add a link to chains reachable from the auth
766 # event.
767 for target_id, target_seq in chain_links.get_links_from(
768 (auth_chain_id, auth_sequence_number)
769 ):
770 if target_id == chain_id:
771 continue
772
773 chain_links.add_link(
774 (chain_id, sequence_number), (target_id, target_seq)
775 )
776
777 db_pool.simple_insert_many_txn(
778 txn,
779 table="event_auth_chain_links",
780 values=[
781 {
782 "origin_chain_id": source_id,
783 "origin_sequence_number": source_seq,
784 "target_chain_id": target_id,
785 "target_sequence_number": target_seq,
786 }
787 for (
788 source_id,
789 source_seq,
790 target_id,
791 target_seq,
792 ) in chain_links.get_additions()
793 ],
794 )
408795
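As a rough illustration of the link rule described in the comments above (the helper and the numbers below are made up for this note and are not part of the diffed code): an event B at chain position (CB, SB) is in the auth chain of an event A at (CA, SA) exactly when some stored link (CA, S1) -> (CB, S2) has S1 <= SA and SB <= S2.

# Sketch only: the reachability test for a single pair of chains.
links = {(7, 3): (2, 5)}  # one stored link: chain 7, seq 3  ->  chain 2, seq 5

def in_auth_chain(a, b):
    """Is the event at position `b` in the auth chain of the event at `a`?"""
    (ca, sa), (cb, sb) = a, b
    if ca == cb:
        return sb <= sa
    return any(
        o_chain == ca and t_chain == cb and o_seq <= sa and sb <= t_seq
        for (o_chain, o_seq), (t_chain, t_seq) in links.items()
    )

assert in_auth_chain((7, 4), (2, 5))      # 3 <= 4 and 5 <= 5, so reachable
assert not in_auth_chain((7, 2), (2, 5))  # the link starts later in chain 7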
409796 def _persist_transaction_ids_txn(
410797 self,
7981185 return [ec for ec in events_and_contexts if ec[0] not in to_remove]
7991186
8001187 def _store_event_txn(self, txn, events_and_contexts):
801 """Insert new events into the event and event_json tables
1188 """Insert new events into the event, event_json, redaction and
1189 state_events tables.
8021190
8031191 Args:
8041192 txn (twisted.enterprise.adbapi.Connection): db connection
8701258 updatevalues={"have_censored": False},
8711259 )
8721260
1261 state_events_and_contexts = [
1262 ec for ec in events_and_contexts if ec[0].is_state()
1263 ]
1264
1265 state_values = []
1266 for event, context in state_events_and_contexts:
1267 vals = {
1268 "event_id": event.event_id,
1269 "room_id": event.room_id,
1270 "type": event.type,
1271 "state_key": event.state_key,
1272 }
1273
1274 # TODO: How does this work with backfilling?
1275 if hasattr(event, "replaces_state"):
1276 vals["prev_state"] = event.replaces_state
1277
1278 state_values.append(vals)
1279
1280 self.db_pool.simple_insert_many_txn(
1281 txn, table="state_events", values=state_values
1282 )
1283
8731284 def _store_rejected_events_txn(self, txn, events_and_contexts):
8741285 """Add rows to the 'rejections' table for received events which were
8751286 rejected
9841395 # Insert event_reference_hashes table.
9851396 self._store_event_reference_hashes_txn(
9861397 txn, [event for event, _ in events_and_contexts]
987 )
988
989 state_events_and_contexts = [
990 ec for ec in events_and_contexts if ec[0].is_state()
991 ]
992
993 state_values = []
994 for event, context in state_events_and_contexts:
995 vals = {
996 "event_id": event.event_id,
997 "room_id": event.room_id,
998 "type": event.type,
999 "state_key": event.state_key,
1000 }
1001
1002 # TODO: How does this work with backfilling?
1003 if hasattr(event, "replaces_state"):
1004 vals["prev_state"] = event.replaces_state
1005
1006 state_values.append(vals)
1007
1008 self.db_pool.simple_insert_many_txn(
1009 txn, table="state_events", values=state_values
10101398 )
10111399
10121400 # Prefill the event cache
15191907 if not ev.internal_metadata.is_outlier()
15201908 ],
15211909 )
1910
1911
1912 @attr.s(slots=True)
1913 class _LinkMap:
1914 """A helper type for tracking links between chains.
1915 """
1916
1917 # Stores the set of links as nested maps: source chain ID -> target chain ID
1918 # -> source sequence number -> target sequence number.
1919 maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
1920
1921 # Stores the links that have been added (with new set to true), as tuples of
1922 # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
1923 additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
1924
1925 def add_link(
1926 self,
1927 src_tuple: Tuple[int, int],
1928 target_tuple: Tuple[int, int],
1929 new: bool = True,
1930 ) -> bool:
1931 """Add a new link between two chains, ensuring no redundant links are added.
1932
1933 New links should be added in topological order.
1934
1935 Args:
1936 src_tuple: The chain ID/sequence number of the source of the link.
1937 target_tuple: The chain ID/sequence number of the target of the link.
1938 new: Whether this is a "new" link, i.e. should it be returned
1939 by `get_additions`.
1940
1941 Returns:
1942 True if a link was added, false if the given link was dropped as redundant
1943 """
1944 src_chain, src_seq = src_tuple
1945 target_chain, target_seq = target_tuple
1946
1947 current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
1948
1949 assert src_chain != target_chain
1950
1951 if new:
1952 # Check if the new link is redundant
1953 for current_seq_src, current_seq_target in current_links.items():
1954 # If a link "crosses" another link then it's redundant. For example,
1955 # in the following diagram link 1 (L1) is redundant, as any event reachable
1956 # via L1 is *also* reachable via L2.
1957 #
1958 # Chain A Chain B
1959 # | |
1960 # L1 |------ |
1961 # | | |
1962 # L2 |---- | -->|
1963 # | | |
1964 # | |--->|
1965 # | |
1966 # | |
1967 #
1968 # So we only need to keep links which *do not* cross, i.e. links
1969 # that both start and end above or below an existing link.
1970 #
1971 # Note, since we add links in topological ordering we should never
1972 # see `src_seq` less than `current_seq_src`.
1973
1974 if current_seq_src <= src_seq and target_seq <= current_seq_target:
1975 # This new link is redundant, nothing to do.
1976 return False
1977
1978 self.additions.add((src_chain, src_seq, target_chain, target_seq))
1979
1980 current_links[src_seq] = target_seq
1981 return True
1982
1983 def get_links_from(
1984 self, src_tuple: Tuple[int, int]
1985 ) -> Generator[Tuple[int, int], None, None]:
1986 """Gets the chains reachable from the given chain/sequence number.
1987
1988 Yields:
1989 The chain ID and sequence number the link points to.
1990 """
1991 src_chain, src_seq = src_tuple
1992 for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
1993 for link_src_seq, target_seq in sequence_numbers.items():
1994 if link_src_seq <= src_seq:
1995 yield target_id, target_seq
1996
1997 def get_links_between(
1998 self, source_chain: int, target_chain: int
1999 ) -> Generator[Tuple[int, int], None, None]:
2000 """Gets the links between two chains.
2001
2002 Yields:
2003 The source and target sequence numbers.
2004 """
2005
2006 yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
2007
2008 def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
2009 """Gets any newly added links.
2010
2011 Yields:
2012 The source chain ID/sequence number and target chain ID/sequence number
2013 """
2014
2015 for src_chain, src_seq, target_chain, _ in self.additions:
2016 target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
2017 if target_seq is not None:
2018 yield (src_chain, src_seq, target_chain, target_seq)
2019
2020 def exists_path_from(
2021 self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
2022 ) -> bool:
2023 """Checks if there is a path between the source chain ID/sequence and
2024 target chain ID/sequence.
2025 """
2026 src_chain, src_seq = src_tuple
2027 target_chain, target_seq = target_tuple
2028
2029 if src_chain == target_chain:
2030 return target_seq <= src_seq
2031
2032 links = self.get_links_between(src_chain, target_chain)
2033 for link_start_seq, link_end_seq in links:
2034 if link_start_seq <= src_seq and target_seq <= link_end_seq:
2035 return True
2036
2037 return False
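A minimal usage sketch of the class above (the values are illustrative and the snippet is not part of the diffed code); it shows the nested `maps` layout and how a crossing link is dropped as redundant:

links = _LinkMap()

assert links.add_link((1, 3), (2, 5))          # chain 1, seq 3 reaches chain 2, seq 5
assert links.maps == {1: {2: {3: 5}}}          # source chain -> target chain -> source seq -> target seq

assert links.exists_path_from((1, 4), (2, 5))      # later on chain 1, target still reachable
assert not links.exists_path_from((1, 2), (2, 5))  # earlier on chain 1, not reachable

assert not links.add_link((1, 4), (2, 4))      # crosses the first link, so it is dropped as redundant
assert list(links.get_additions()) == [(1, 3, 2, 5)]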
1313 # limitations under the License.
1414
1515 import logging
16 from typing import Dict, List, Optional, Tuple
17
18 import attr
1619
1720 from synapse.api.constants import EventContentFields
21 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
22 from synapse.events import make_event_from_dict
1823 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
19 from synapse.storage.database import DatabasePool
24 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
25 from synapse.storage.databases.main.events import PersistEventsStore
26 from synapse.storage.types import Cursor
27 from synapse.types import JsonDict
2028
2129 logger = logging.getLogger(__name__)
30
31
32 @attr.s(slots=True, frozen=True)
33 class _CalculateChainCover:
34 """Return value for _calculate_chain_cover_txn.
35 """
36
37 # The last room_id/depth/stream processed.
38 room_id = attr.ib(type=str)
39 depth = attr.ib(type=int)
40 stream = attr.ib(type=int)
41
42 # Number of rows processed
43 processed_count = attr.ib(type=int)
44
45 # Map from room_id to last depth/stream processed for each room that we have
46 # processed all events for (i.e. the rooms we can flip the
47 # `has_auth_chain_index` for)
48 finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
2249
2350
2451 class EventsBackgroundUpdatesStore(SQLBaseStore):
96123 index_name="users_have_local_media",
97124 table="local_media_repository",
98125 columns=["user_id", "created_ts"],
126 )
127
128 self.db_pool.updates.register_background_update_handler(
129 "rejected_events_metadata", self._rejected_events_metadata,
130 )
131
132 self.db_pool.updates.register_background_update_handler(
133 "chain_cover", self._chain_cover_index,
99134 )
100135
101136 async def _background_reindex_fields_sender(self, progress, batch_size):
581616 await self.db_pool.updates._end_background_update("event_store_labels")
582617
583618 return num_rows
619
620 async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
621 """Adds rejected events to the `state_events` and `event_auth` metadata
622 tables.
623 """
624
625 last_event_id = progress.get("last_event_id", "")
626
627 def get_rejected_events(
628 txn: Cursor,
629 ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
630 # Fetch rejected event json, their room version and whether we have
631 # inserted them into the state_events or auth_events tables.
632 #
633 # Note we can assume that events that don't have a corresponding
634 # room version are V1 rooms.
635 sql = """
636 SELECT DISTINCT
637 event_id,
638 COALESCE(room_version, '1'),
639 json,
640 state_events.event_id IS NOT NULL,
641 event_auth.event_id IS NOT NULL
642 FROM rejections
643 INNER JOIN event_json USING (event_id)
644 LEFT JOIN rooms USING (room_id)
645 LEFT JOIN state_events USING (event_id)
646 LEFT JOIN event_auth USING (event_id)
647 WHERE event_id > ?
648 ORDER BY event_id
649 LIMIT ?
650 """
651
652 txn.execute(sql, (last_event_id, batch_size,))
653
654 return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
655
656 results = await self.db_pool.runInteraction(
657 desc="_rejected_events_metadata_get", func=get_rejected_events
658 )
659
660 if not results:
661 await self.db_pool.updates._end_background_update(
662 "rejected_events_metadata"
663 )
664 return 0
665
666 state_events = []
667 auth_events = []
668 for event_id, room_version, event_json, has_state, has_event_auth in results:
669 last_event_id = event_id
670
671 if has_state and has_event_auth:
672 continue
673
674 room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
675 if not room_version_obj:
676 # We no longer support this room version, so we just ignore the
677 # events entirely.
678 logger.info(
679 "Ignoring event with unknown room version %r: %r",
680 room_version,
681 event_id,
682 )
683 continue
684
685 event = make_event_from_dict(event_json, room_version_obj)
686
687 if not event.is_state():
688 continue
689
690 if not has_state:
691 state_events.append(
692 {
693 "event_id": event.event_id,
694 "room_id": event.room_id,
695 "type": event.type,
696 "state_key": event.state_key,
697 }
698 )
699
700 if not has_event_auth:
701 for auth_id in event.auth_event_ids():
702 auth_events.append(
703 {
704 "room_id": event.room_id,
705 "event_id": event.event_id,
706 "auth_id": auth_id,
707 }
708 )
709
710 if state_events:
711 await self.db_pool.simple_insert_many(
712 table="state_events",
713 values=state_events,
714 desc="_rejected_events_metadata_state_events",
715 )
716
717 if auth_events:
718 await self.db_pool.simple_insert_many(
719 table="event_auth",
720 values=auth_events,
721 desc="_rejected_events_metadata_event_auth",
722 )
723
724 await self.db_pool.updates._background_update_progress(
725 "rejected_events_metadata", {"last_event_id": last_event_id}
726 )
727
728 if len(results) < batch_size:
729 await self.db_pool.updates._end_background_update(
730 "rejected_events_metadata"
731 )
732
733 return len(results)
734
735 async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
736 """A background updates that iterates over all rooms and generates the
737 chain cover index for them.
738 """
739
740 current_room_id = progress.get("current_room_id", "")
741
742 # Where we've processed up to in the room, defaults to the start of the
743 # room.
744 last_depth = progress.get("last_depth", -1)
745 last_stream = progress.get("last_stream", -1)
746
747 result = await self.db_pool.runInteraction(
748 "_chain_cover_index",
749 self._calculate_chain_cover_txn,
750 current_room_id,
751 last_depth,
752 last_stream,
753 batch_size,
754 single_room=False,
755 )
756
757 finished = result.processed_count == 0
758
759 total_rows_processed = result.processed_count
760 current_room_id = result.room_id
761 last_depth = result.depth
762 last_stream = result.stream
763
764 for room_id, (depth, stream) in result.finished_room_map.items():
765 # If we've done all the events in the room we flip the
766 # `has_auth_chain_index` flag in the DB. Note that it's possible for
767 # further events to be persisted between the calculation above and setting the
768 # flag without having the chain cover calculated for them. This is
769 # fine as a) the code gracefully handles these cases and b) we'll
770 # calculate them below.
771
772 await self.db_pool.simple_update(
773 table="rooms",
774 keyvalues={"room_id": room_id},
775 updatevalues={"has_auth_chain_index": True},
776 desc="_chain_cover_index",
777 )
778
779 # Handle any events that might have raced with us flipping the
780 # bit above.
781 result = await self.db_pool.runInteraction(
782 "_chain_cover_index",
783 self._calculate_chain_cover_txn,
784 room_id,
785 depth,
786 stream,
787 batch_size=None,
788 single_room=True,
789 )
790
791 total_rows_processed += result.processed_count
792
793 if finished:
794 await self.db_pool.updates._end_background_update("chain_cover")
795 return total_rows_processed
796
797 await self.db_pool.updates._background_update_progress(
798 "chain_cover",
799 {
800 "current_room_id": current_room_id,
801 "last_depth": last_depth,
802 "last_stream": last_stream,
803 },
804 )
805
806 return total_rows_processed
807
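For reference, a sketch of the resumption state this update round-trips (the values are made up; the real dict is simply whatever was last passed to `_background_update_progress` above):

progress = {
    "current_room_id": "!room:example.org",
    "last_depth": 42,
    "last_stream": 1234,
}
# On the next run, _chain_cover_index reads these back with
# progress.get("current_room_id", ""), progress.get("last_depth", -1) and
# progress.get("last_stream", -1), and carries on from that position.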
808 def _calculate_chain_cover_txn(
809 self,
810 txn: Cursor,
811 last_room_id: str,
812 last_depth: int,
813 last_stream: int,
814 batch_size: Optional[int],
815 single_room: bool,
816 ) -> _CalculateChainCover:
817 """Calculate the chain cover for `batch_size` events, ordered by
818 `(room_id, depth, stream)`.
819
820 Args:
821 txn,
822 last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
823 tuple to fetch results after.
824 batch_size: The maximum number of events to process. If None then
825 no limit.
826 single_room: Whether to calculate the index for just the given
827 room.
828 """
829
830 # Get the next set of events in the room (that we haven't already
831 # computed chain cover for). We do this in topological order.
832
833 # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
834 # comparison, but that is not supported on older SQLite versions
835 tuple_clause, tuple_args = make_tuple_comparison_clause(
836 self.database_engine,
837 [
838 ("events.room_id", last_room_id),
839 ("topological_ordering", last_depth),
840 ("stream_ordering", last_stream),
841 ],
842 )
843
844 extra_clause = ""
845 if single_room:
846 extra_clause = "AND events.room_id = ?"
847 tuple_args.append(last_room_id)
848
849 sql = """
850 SELECT
851 event_id, state_events.type, state_events.state_key,
852 topological_ordering, stream_ordering,
853 events.room_id
854 FROM events
855 INNER JOIN state_events USING (event_id)
856 LEFT JOIN event_auth_chains USING (event_id)
857 LEFT JOIN event_auth_chain_to_calculate USING (event_id)
858 WHERE event_auth_chains.event_id IS NULL
859 AND event_auth_chain_to_calculate.event_id IS NULL
860 AND %(tuple_cmp)s
861 %(extra)s
862 ORDER BY events.room_id, topological_ordering, stream_ordering
863 %(limit)s
864 """ % {
865 "tuple_cmp": tuple_clause,
866 "limit": "LIMIT ?" if batch_size is not None else "",
867 "extra": extra_clause,
868 }
869
870 if batch_size is not None:
871 tuple_args.append(batch_size)
872
873 txn.execute(sql, tuple_args)
874 rows = txn.fetchall()
875
876 # Put the results in the necessary format for
877 # `_add_chain_cover_index`
878 event_to_room_id = {row[0]: row[5] for row in rows}
879 event_to_types = {row[0]: (row[1], row[2]) for row in rows}
880
881 # Calculate the new last position we've processed up to.
882 new_last_depth = rows[-1][3] if rows else last_depth # type: int
883 new_last_stream = rows[-1][4] if rows else last_stream # type: int
884 new_last_room_id = rows[-1][5] if rows else "" # type: str
885
886 # Map from room_id to last depth/stream_ordering processed for the room,
887 # excluding the last room (which we're likely still processing). We also
888 # need to include the room passed in if it's not included in the result
889 # set (as we then know we've processed all events in said room).
890 #
891 # This is the set of rooms that we can now safely flip the
892 # `has_auth_chain_index` bit for.
893 finished_rooms = {
894 row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
895 }
896 if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
897 finished_rooms[last_room_id] = (last_depth, last_stream)
898
899 count = len(rows)
900
901 # We also need to fetch the auth events for them.
902 auth_events = self.db_pool.simple_select_many_txn(
903 txn,
904 table="event_auth",
905 column="event_id",
906 iterable=event_to_room_id,
907 keyvalues={},
908 retcols=("event_id", "auth_id"),
909 )
910
911 event_to_auth_chain = {} # type: Dict[str, List[str]]
912 for row in auth_events:
913 event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
914
915 # Calculate and persist the chain cover index for this set of events.
916 #
917 # Annoyingly we need to gut wrench into the persist event store so that
918 # we can reuse the function to calculate the chain cover for rooms.
919 PersistEventsStore._add_chain_cover_index(
920 txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
921 )
922
923 return _CalculateChainCover(
924 room_id=new_last_room_id,
925 depth=new_last_depth,
926 stream=new_last_stream,
927 processed_count=count,
928 finished_room_map=finished_rooms,
929 )
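The `make_tuple_comparison_clause` call above stands in for a row-value comparison that older SQLite versions lack; the exact SQL it emits depends on the database engine and is not reproduced here. The intended semantics, sketched in plain Python with made-up values:

# Sketch only: resume strictly after the last processed position, comparing
# (room_id, topological_ordering, stream_ordering) lexicographically.
last = ("!room:example.org", 10, 99)

def is_after(room_id, depth, stream):
    return (room_id, depth, stream) > last

assert is_after("!room:example.org", 10, 100)   # same room and depth, later stream
assert not is_after("!room:example.org", 10, 99)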
9595 db=database,
9696 stream_name="events",
9797 instance_name=hs.get_instance_name(),
98 table="events",
99 instance_column="instance_name",
100 id_column="stream_ordering",
98 tables=[("events", "instance_name", "stream_ordering")],
10199 sequence_name="events_stream_seq",
102100 writers=hs.config.worker.writers.events,
103101 )
106104 db=database,
107105 stream_name="backfill",
108106 instance_name=hs.get_instance_name(),
109 table="events",
110 instance_column="instance_name",
111 id_column="stream_ordering",
107 tables=[("events", "instance_name", "stream_ordering")],
112108 sequence_name="events_backfill_stream_seq",
113109 positive=False,
114110 writers=hs.config.worker.writers.events,
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
168169
169170 async def get_local_media_before(
170171 self, before_ts: int, size_gt: int, keep_profiles: bool,
171 ) -> Optional[List[str]]:
172 ) -> List[str]:
172173
173174 # to find files that have never been accessed (last_access_ts IS NULL)
174175 # compare with `created_ts`
8181 )
8282
8383 async def set_profile_avatar_url(
84 self, user_localpart: str, new_avatar_url: str
84 self, user_localpart: str, new_avatar_url: Optional[str]
8585 ) -> None:
8686 await self.db_pool.simple_update_one(
8787 table="profiles",
1616 import logging
1717 from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
1818
19 from canonicaljson import encode_canonical_json
20
2119 from synapse.push import PusherConfig, ThrottleParams
2220 from synapse.storage._base import SQLBaseStore, db_to_json
2321 from synapse.storage.database import DatabasePool
2422 from synapse.storage.types import Connection
2523 from synapse.storage.util.id_generators import StreamIdGenerator
2624 from synapse.types import JsonDict
25 from synapse.util import json_encoder
2726 from synapse.util.caches.descriptors import cached, cachedList
2827
2928 if TYPE_CHECKING:
314313 "device_display_name": device_display_name,
315314 "ts": pushkey_ts,
316315 "lang": lang,
317 "data": bytearray(encode_canonical_json(data)),
316 "data": json_encoder.encode(data),
318317 "last_stream_ordering": last_stream_ordering,
319318 "profile_tag": profile_tag,
320319 "id": stream_id,
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 import abc
1716 import logging
1817 from typing import Any, Dict, List, Optional, Tuple
1918
2019 from twisted.internet import defer
2120
21 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
22 from synapse.replication.tcp.streams import ReceiptsStream
2223 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
2324 from synapse.storage.database import DatabasePool
24 from synapse.storage.util.id_generators import StreamIdGenerator
25 from synapse.storage.engines import PostgresEngine
26 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
2527 from synapse.types import JsonDict
2628 from synapse.util import json_encoder
2729 from synapse.util.caches.descriptors import cached, cachedList
3032 logger = logging.getLogger(__name__)
3133
3234
33 # The ABCMeta metaclass ensures that it cannot be instantiated without
34 # the abstract methods being implemented.
35 class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
36 """This is an abstract base class where subclasses must implement
37 `get_max_receipt_stream_id` which can be called in the initializer.
38 """
39
35 class ReceiptsWorkerStore(SQLBaseStore):
4036 def __init__(self, database: DatabasePool, db_conn, hs):
37 self._instance_name = hs.get_instance_name()
38
39 if isinstance(database.engine, PostgresEngine):
40 self._can_write_to_receipts = (
41 self._instance_name in hs.config.worker.writers.receipts
42 )
43
44 self._receipts_id_gen = MultiWriterIdGenerator(
45 db_conn=db_conn,
46 db=database,
47 stream_name="receipts",
48 instance_name=self._instance_name,
49 tables=[("receipts_linearized", "instance_name", "stream_id")],
50 sequence_name="receipts_sequence",
51 writers=hs.config.worker.writers.receipts,
52 )
53 else:
54 self._can_write_to_receipts = True
55
56 # We shouldn't be running in worker mode with SQLite, but it's useful
57 # to support it for unit tests.
58 #
59 # If this process is the writer then we need to use
60 # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
61 # updated over replication. (Multiple writers are not supported for
62 # SQLite).
63 if hs.get_instance_name() in hs.config.worker.writers.receipts:
64 self._receipts_id_gen = StreamIdGenerator(
65 db_conn, "receipts_linearized", "stream_id"
66 )
67 else:
68 self._receipts_id_gen = SlavedIdTracker(
69 db_conn, "receipts_linearized", "stream_id"
70 )
71
4172 super().__init__(database, db_conn, hs)
4273
4374 self._receipts_stream_cache = StreamChangeCache(
4475 "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
4576 )
4677
47 @abc.abstractmethod
4878 def get_max_receipt_stream_id(self):
4979 """Get the current max stream ID for receipts stream
5080
5181 Returns:
5282 int
5383 """
54 raise NotImplementedError()
84 return self._receipts_id_gen.get_current_token()
5585
5686 @cached()
5787 async def get_users_with_read_receipts_in_room(self, room_id):
427457
428458 self.get_users_with_read_receipts_in_room.invalidate((room_id,))
429459
430
431 class ReceiptsStore(ReceiptsWorkerStore):
432 def __init__(self, database: DatabasePool, db_conn, hs):
433 # We instantiate this first as the ReceiptsWorkerStore constructor
434 # needs to be able to call get_max_receipt_stream_id
435 self._receipts_id_gen = StreamIdGenerator(
436 db_conn, "receipts_linearized", "stream_id"
437 )
438
439 super().__init__(database, db_conn, hs)
440
441 def get_max_receipt_stream_id(self):
442 return self._receipts_id_gen.get_current_token()
460 def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
461 self.get_receipts_for_user.invalidate((user_id, receipt_type))
462 self._get_linearized_receipts_for_room.invalidate_many((room_id,))
463 self.get_last_receipt_event_id_for_user.invalidate(
464 (user_id, room_id, receipt_type)
465 )
466 self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
467 self.get_receipts_for_room.invalidate((room_id, receipt_type))
468
469 def process_replication_rows(self, stream_name, instance_name, token, rows):
470 if stream_name == ReceiptsStream.NAME:
471 self._receipts_id_gen.advance(instance_name, token)
472 for row in rows:
473 self.invalidate_caches_for_receipt(
474 row.room_id, row.receipt_type, row.user_id
475 )
476 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
477
478 return super().process_replication_rows(stream_name, instance_name, token, rows)
443479
444480 def insert_linearized_receipt_txn(
445481 self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
451487 otherwise, the rx timestamp of the event that the RR corresponds to
452488 (or 0 if the event is unknown)
453489 """
490 assert self._can_write_to_receipts
491
454492 res = self.db_pool.simple_select_one_txn(
455493 txn,
456494 table="events",
482520 )
483521 return None
484522
485 txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
486523 txn.call_after(
487 self._invalidate_get_users_with_receipts_in_room,
488 room_id,
489 receipt_type,
490 user_id,
491 )
492 txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
493 # FIXME: This shouldn't invalidate the whole cache
494 txn.call_after(
495 self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
524 self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
496525 )
497526
498527 txn.call_after(
499528 self._receipts_stream_cache.entity_has_changed, room_id, stream_id
500 )
501
502 txn.call_after(
503 self.get_last_receipt_event_id_for_user.invalidate,
504 (user_id, room_id, receipt_type),
505529 )
506530
507531 self.db_pool.simple_upsert_txn(
542566 Automatically does conversion between linearized and graph
543567 representations.
544568 """
569 assert self._can_write_to_receipts
570
545571 if not event_ids:
546572 return None
547573
606632 async def insert_graph_receipt(
607633 self, room_id, receipt_type, user_id, event_ids, data
608634 ):
635 assert self._can_write_to_receipts
636
609637 return await self.db_pool.runInteraction(
610638 "insert_graph_receipt",
611639 self.insert_graph_receipt_txn,
619647 def insert_graph_receipt_txn(
620648 self, txn, room_id, receipt_type, user_id, event_ids, data
621649 ):
650 assert self._can_write_to_receipts
651
622652 txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
623653 txn.call_after(
624654 self._invalidate_get_users_with_receipts_in_room,
652682 "data": json_encoder.encode(data),
653683 },
654684 )
685
686
687 class ReceiptsStore(ReceiptsWorkerStore):
688 pass
1515
1616 import collections
1717 import logging
18 import re
1918 from abc import abstractmethod
2019 from enum import Enum
2120 from typing import Any, Dict, List, Optional, Tuple
2928 from synapse.types import JsonDict, ThirdPartyInstanceID
3029 from synapse.util import json_encoder
3130 from synapse.util.caches.descriptors import cached
31 from synapse.util.stringutils import MXC_REGEX
3232
3333 logger = logging.getLogger(__name__)
3434
8383 return await self.db_pool.simple_select_one(
8484 table="rooms",
8585 keyvalues={"room_id": room_id},
86 retcols=("room_id", "is_public", "creator"),
86 retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
8787 desc="get_room",
8888 allow_none=True,
8989 )
659659 The local and remote media as a lists of tuples where the key is
660660 the hostname and the value is the media ID.
661661 """
662 mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
663
664662 sql = """
665663 SELECT stream_ordering, json FROM events
666664 JOIN event_json USING (room_id, event_id)
687685 for url in (content_url, thumbnail_url):
688686 if not url:
689687 continue
690 matches = mxc_re.match(url)
688 matches = MXC_REGEX.match(url)
691689 if matches:
692690 hostname = matches.group(1)
693691 media_id = matches.group(2)
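For reference, the pattern below is the one shown on the removed line above; whether `MXC_REGEX` in `synapse.util.stringutils` uses exactly the same pattern is an assumption of this sketch:

import re

MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)")  # same pattern as the removed inline regex

m = MXC_REGEX.match("mxc://example.org/SEsfnsuifSDFSSEF")
assert m is not None
assert m.group(1) == "example.org"        # hostname
assert m.group(2) == "SEsfnsuifSDFSSEF"   # media ID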
11651163 # It's overridden by RoomStore for the synapse master.
11661164 raise NotImplementedError()
11671165
1166 async def has_auth_chain_index(self, room_id: str) -> bool:
1167 """Check if the room has (or can have) a chain cover index.
1168
1169 Defaults to True if we don't have an entry in `rooms` table nor any
1170 events for the room.
1171 """
1172
1173 has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
1174 table="rooms",
1175 keyvalues={"room_id": room_id},
1176 retcol="has_auth_chain_index",
1177 desc="has_auth_chain_index",
1178 allow_none=True,
1179 )
1180
1181 if has_auth_chain_index:
1182 return True
1183
1184 # It's possible that we already have events for the room in our DB
1185 # without a corresponding room entry. If we do then we don't want to
1186 # mark the room as having an auth chain cover index.
1187 max_ordering = await self.db_pool.simple_select_one_onecol(
1188 table="events",
1189 keyvalues={"room_id": room_id},
1190 retcol="MAX(stream_ordering)",
1191 allow_none=True,
1192 desc="has_auth_chain_index",
1193 )
1194
1195 return max_ordering is None
1196
11681197
11691198 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
11701199 def __init__(self, database: DatabasePool, db_conn, hs):
11781207 Called when we join a room over federation, and overwrites any room version
11791208 currently in the table.
11801209 """
1210 # It's possible that we already have events for the room in our DB
1211 # without a corresponding room entry. If we do then we don't want to
1212 # mark the room as having an auth chain cover index.
1213 has_auth_chain_index = await self.has_auth_chain_index(room_id)
1214
11811215 await self.db_pool.simple_upsert(
11821216 desc="upsert_room_on_join",
11831217 table="rooms",
11841218 keyvalues={"room_id": room_id},
11851219 values={"room_version": room_version.identifier},
1186 insertion_values={"is_public": False, "creator": ""},
1220 insertion_values={
1221 "is_public": False,
1222 "creator": "",
1223 "has_auth_chain_index": has_auth_chain_index,
1224 },
11871225 # rooms has a unique constraint on room_id, so no need to lock when doing an
11881226 # emulated upsert.
11891227 lock=False,
12181256 "creator": room_creator_user_id,
12191257 "is_public": is_public,
12201258 "room_version": room_version.identifier,
1259 "has_auth_chain_index": True,
12211260 },
12221261 )
12231262 if is_public:
12461285 When we receive an invite or any other event over federation that may relate to a room
12471286 we are not in, store the version of the room if we don't already know the room version.
12481287 """
1288 # It's possible that we already have events for the room in our DB
1289 # without a corresponding room entry. If we do then we don't want to
1290 # mark the room as having an auth chain cover index.
1291 has_auth_chain_index = await self.has_auth_chain_index(room_id)
1292
12491293 await self.db_pool.simple_upsert(
12501294 desc="maybe_store_room_on_outlier_membership",
12511295 table="rooms",
12551299 "room_version": room_version.identifier,
12561300 "is_public": False,
12571301 "creator": "",
1302 "has_auth_chain_index": has_auth_chain_index,
12581303 },
12591304 # rooms has a unique constraint on room_id, so no need to lock when doing an
12601305 # emulated upsert.
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE access_tokens DROP COLUMN last_used;
0 /*
1 * Copyright 2020 The Matrix.org Foundation C.I.C.
2 *
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 -- Dropping last_used column from access_tokens table.
17
18 CREATE TABLE access_tokens2 (
19 id BIGINT PRIMARY KEY,
20 user_id TEXT NOT NULL,
21 device_id TEXT,
22 token TEXT NOT NULL,
23 valid_until_ms BIGINT,
24 puppets_user_id TEXT,
25 last_validated BIGINT,
26 UNIQUE(token)
27 );
28
29 INSERT INTO access_tokens2(id, user_id, device_id, token)
30 SELECT id, user_id, device_id, token FROM access_tokens;
31
32 DROP TABLE access_tokens;
33 ALTER TABLE access_tokens2 RENAME TO access_tokens;
34
35 CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
36
37
38 -- Re-adding foreign key reference in event_txn_id table
39
40 CREATE TABLE event_txn_id2 (
41 event_id TEXT NOT NULL,
42 room_id TEXT NOT NULL,
43 user_id TEXT NOT NULL,
44 token_id BIGINT NOT NULL,
45 txn_id TEXT NOT NULL,
46 inserted_ts BIGINT NOT NULL,
47 FOREIGN KEY (event_id)
48 REFERENCES events (event_id) ON DELETE CASCADE,
49 FOREIGN KEY (token_id)
50 REFERENCES access_tokens (id) ON DELETE CASCADE
51 );
52
53 INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
54 SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
55
56 DROP TABLE event_txn_id;
57 ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
58
59 CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
60 CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
61 CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
16 (5828, 'rejected_events_metadata', '{}');
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """
15 This migration denormalises the account_data table into an ignored users table.
16 """
17
18 import logging
19 from io import StringIO
20
21 from synapse.storage._base import db_to_json
22 from synapse.storage.engines import BaseDatabaseEngine
23 from synapse.storage.prepare_database import execute_statements_from_stream
24 from synapse.storage.types import Cursor
25
26 logger = logging.getLogger(__name__)
27
28
29 def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
30 pass
31
32
33 def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
34 logger.info("Creating ignored_users table")
35 execute_statements_from_stream(cur, StringIO(_create_commands))
36
37 # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
38 # we a) want to run these before adding constraints and b) `run_upgrade` is
39 # not run on empty databases.
40 insert_sql = """
41 INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
42 """
43
44 logger.info("Converting existing ignore lists")
45 cur.execute(
46 "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
47 )
48 for user_id, content_json in cur.fetchall():
49 content = db_to_json(content_json)
50
51 # The content should be in the form of a dictionary with a key
52 # "ignored_users" pointing to a dictionary with keys of ignored users.
53 #
54 # { "ignored_users": "@someone:example.org": {} }
55 ignored_users = content.get("ignored_users", {})
56 if isinstance(ignored_users, dict) and ignored_users:
57 cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
58
59 # Add indexes after inserting data for efficiency.
60 logger.info("Adding constraints to ignored_users table")
61 execute_statements_from_stream(cur, StringIO(_constraints_commands))
62
63
64 # there might be duplicates, so the easiest way to achieve this is to create a new
65 # table with the right data, and renaming it into place
66
67 _create_commands = """
68 -- Users which are ignored when calculating push notifications. This data is
69 -- denormalized from account data.
70 CREATE TABLE IF NOT EXISTS ignored_users(
71 ignorer_user_id TEXT NOT NULL, -- The user ID of the user who is ignoring another user. (This is a local user.)
72 ignored_user_id TEXT NOT NULL -- The user ID of the user who is being ignored. (This is a local or remote user.)
73 );
74 """
75
76 _constraints_commands = """
77 CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id);
78
79 -- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user.
80 CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id);
81 """
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
16 ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
17 ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
16
17 -- We need to take the max across both device_inbox and device_federation_outbox
18 -- tables as they share the ID generator
19 SELECT setval('device_inbox_sequence', (
20 SELECT GREATEST(
21 (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
22 (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
23 )
24 ));
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- See docs/auth_chain_difference_algorithm.md
16
17 CREATE TABLE event_auth_chains (
18 event_id TEXT PRIMARY KEY,
19 chain_id BIGINT NOT NULL,
20 sequence_number BIGINT NOT NULL
21 );
22
23 CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
24
25
26 CREATE TABLE event_auth_chain_links (
27 origin_chain_id BIGINT NOT NULL,
28 origin_sequence_number BIGINT NOT NULL,
29
30 target_chain_id BIGINT NOT NULL,
31 target_sequence_number BIGINT NOT NULL
32 );
33
34
35 CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
36
37
38 -- Events that we have persisted but not calculated auth chains for,
39 -- e.g. out of band memberships (where we don't have the auth chain)
40 CREATE TABLE event_auth_chain_to_calculate (
41 event_id TEXT PRIMARY KEY,
42 room_id TEXT NOT NULL,
43 type TEXT NOT NULL,
44 state_key TEXT NOT NULL
45 );
46
47 CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
48
49
50 -- Whether we've calculated the above index for a room.
51 ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- This is no longer used and was only kept until we bumped the schema version.
16 DROP TABLE IF EXISTS account_data_max_stream_id;
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- This is no longer used and was only kept until we bumped the schema version.
16 DROP TABLE IF EXISTS cache_invalidation_stream;
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
16 (5906, 'chain_cover', '{}', 'rejected_events_metadata');
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
16 ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
17 ALTER TABLE account_data ADD COLUMN instance_name TEXT;
18
19 ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
16
17 -- We need to take the max across all the account_data tables as they share the
18 -- ID generator
19 SELECT setval('account_data_sequence', (
20 SELECT GREATEST(
21 (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
22 (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
23 (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
24 )
25 ));
26
27 CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
28
29 SELECT setval('receipts_sequence', (
30 SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
31 ));
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- We incorrectly populated these, so we delete them and let the
16 -- MultiWriterIdGenerator repopulate them.
17 DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data';
182182 )
183183 return {row["tag"]: db_to_json(row["content"]) for row in rows}
184184
185
186 class TagsStore(TagsWorkerStore):
187185 async def add_tag_to_room(
188186 self, user_id: str, room_id: str, tag: str, content: JsonDict
189187 ) -> int:
198196 Returns:
199197 The next account data ID.
200198 """
199 assert self._can_write_to_account_data
200
201201 content_json = json_encoder.encode(content)
202202
203203 def add_tag_txn(txn, next_id):
222222 Returns:
223223 The next account data ID.
224224 """
225 assert self._can_write_to_account_data
225226
226227 def remove_tag_txn(txn, next_id):
227228 sql = (
249250 room_id: The ID of the room.
250251 next_id: The revision to advance to.
251252 """
253 assert self._can_write_to_account_data
252254
253255 txn.call_after(
254256 self._account_data_stream_cache.entity_has_changed, user_id, next_id
255257 )
256
257 # Note: This is only here for backwards compat to allow admins to
258 # roll back to a previous Synapse version. Next time we update the
259 # database version we can remove this table.
260 update_max_id_sql = (
261 "UPDATE account_data_max_stream_id"
262 " SET stream_id = ?"
263 " WHERE stream_id < ?"
264 )
265 txn.execute(update_max_id_sql, (next_id, next_id))
266258
267259 update_sql = (
268260 "UPDATE room_tags_revisions"
287279 # which stream_id ends up in the table, as long as it is higher
288280 # than the id that the client has.
289281 pass
282
283
284 class TagsStore(TagsWorkerStore):
285 pass
463463 txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
464464 ) -> List[str]:
465465 q = """
466 SELECT destination FROM destinations
467 WHERE destination IN (
468 SELECT destination FROM destination_rooms
469 WHERE destination_rooms.stream_ordering >
470 destinations.last_successful_stream_ordering
471 )
472 AND destination > ?
473 AND (
474 retry_last_ts IS NULL OR
475 retry_last_ts + retry_interval < ?
476 )
477 ORDER BY destination
478 LIMIT 25
466 SELECT DISTINCT destination FROM destinations
467 INNER JOIN destination_rooms USING (destination)
468 WHERE
469 stream_ordering > last_successful_stream_ordering
470 AND destination > ?
471 AND (
472 retry_last_ts IS NULL OR
473 retry_last_ts + retry_interval < ?
474 )
475 ORDER BY destination
476 LIMIT 25
479477 """
480478 txn.execute(
481479 q,
3434
3535 # Remember to update this number every time a change is made to database
3636 # schema files, so the users will be informed on server restarts.
37 # XXX: If you're about to bump this to 59 (or higher) please create an update
38 # that drops the unused `cache_invalidation_stream` table, as per #7436!
39 # XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
40 SCHEMA_VERSION = 58
37 SCHEMA_VERSION = 59
4138
4239 dir_path = os.path.abspath(os.path.dirname(__file__))
4340
374371 specific_engine_extensions = (".sqlite", ".postgres")
375372
376373 for v in range(start_ver, SCHEMA_VERSION + 1):
377 logger.info("Applying schema deltas for v%d", v)
374 if not is_worker:
375 logger.info("Applying schema deltas for v%d", v)
376
377 cur.execute("DELETE FROM schema_version")
378 cur.execute(
379 "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
380 (v, True),
381 )
382 else:
383 logger.info("Checking schema deltas for v%d", v)
378384
379385 # We need to search both the global and per data store schema
380386 # directories for schema updates.
488494 (v, relative_path),
489495 )
490496
491 cur.execute("DELETE FROM schema_version")
492 cur.execute(
493 "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
494 (v, True),
495 )
496
497497 logger.info("Schema now up to date")
498498
499499
1616 import threading
1717 from collections import deque
1818 from contextlib import contextmanager
19 from typing import Dict, List, Optional, Set, Union
19 from typing import Dict, List, Optional, Set, Tuple, Union
2020
2121 import attr
2222 from typing_extensions import Deque
185185 Args:
186186 db_conn
187187 db
188 stream_name: A name for the stream.
188 stream_name: A name for the stream, for use in the `stream_positions`
189 table. (Does not need to be the same as the replication stream name)
189190 instance_name: The name of this instance.
190 table: Database table associated with stream.
191 instance_column: Column that stores the row's writer's instance name
192 id_column: Column that stores the stream ID.
191 tables: List of tables associated with the stream. Tuple of table
192 name, column name that stores the writer's instance name, and
193 column name that stores the stream ID.
193194 sequence_name: The name of the postgres sequence used to generate new
194195 IDs.
195196 writers: A list of known writers to use to populate current positions
205206 db: DatabasePool,
206207 stream_name: str,
207208 instance_name: str,
208 table: str,
209 instance_column: str,
210 id_column: str,
209 tables: List[Tuple[str, str, str]],
211210 sequence_name: str,
212211 writers: List[str],
213212 positive: bool = True,
259258 self._sequence_gen = PostgresSequenceGenerator(sequence_name)
260259
261260 # We check that the table and sequence haven't diverged.
262 self._sequence_gen.check_consistency(
263 db_conn, table=table, id_column=id_column, positive=positive
264 )
261 for table, _, id_column in tables:
262 self._sequence_gen.check_consistency(
263 db_conn,
264 table=table,
265 id_column=id_column,
266 stream_name=stream_name,
267 positive=positive,
268 )
265269
266270 # This goes and fills out the above state from the database.
267 self._load_current_ids(db_conn, table, instance_column, id_column)
271 self._load_current_ids(db_conn, tables)
268272
269273 def _load_current_ids(
270 self, db_conn, table: str, instance_column: str, id_column: str
274 self, db_conn, tables: List[Tuple[str, str, str]],
271275 ):
272276 cur = db_conn.cursor(txn_name="_load_current_ids")
273277
305309 # We add a GREATEST here to ensure that the result is always
306310 # positive. (This can be a problem for e.g. backfill streams where
307311 # the server has never backfilled).
308 sql = """
309 SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
310 FROM %(table)s
311 """ % {
312 "id": id_column,
313 "table": table,
314 "agg": "MAX" if self._positive else "-MIN",
315 }
316 cur.execute(sql)
317 (stream_id,) = cur.fetchone()
318 self._persisted_upto_position = stream_id
312 max_stream_id = 1
313 for table, _, id_column in tables:
314 sql = """
315 SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
316 FROM %(table)s
317 """ % {
318 "id": id_column,
319 "table": table,
320 "agg": "MAX" if self._positive else "-MIN",
321 }
322 cur.execute(sql)
323 (stream_id,) = cur.fetchone()
324
325 max_stream_id = max(max_stream_id, stream_id)
326
327 self._persisted_upto_position = max_stream_id
319328 else:
320329 # If we have a min_stream_id then we pull out everything greater
321330 # than it from the DB so that we can prefill
328337 # stream positions table before restart (or the stream position
329338 # table otherwise got out of date).
330339
331 sql = """
332 SELECT %(instance)s, %(id)s FROM %(table)s
333 WHERE ? %(cmp)s %(id)s
334 """ % {
335 "id": id_column,
336 "table": table,
337 "instance": instance_column,
338 "cmp": "<=" if self._positive else ">=",
339 }
340 cur.execute(sql, (min_stream_id * self._return_factor,))
341
342340 self._persisted_upto_position = min_stream_id
343341
342 rows = []
343 for table, instance_column, id_column in tables:
344 sql = """
345 SELECT %(instance)s, %(id)s FROM %(table)s
346 WHERE ? %(cmp)s %(id)s
347 """ % {
348 "id": id_column,
349 "table": table,
350 "instance": instance_column,
351 "cmp": "<=" if self._positive else ">=",
352 }
353 cur.execute(sql, (min_stream_id * self._return_factor,))
354
355 rows.extend(cur)
356
357 # Sort so that we handle rows in order for each instance.
358 rows.sort()
359
344360 with self._lock:
345 for (instance, stream_id,) in cur:
361 for (instance, stream_id,) in rows:
346362 stream_id = self._return_factor * stream_id
347363 self._add_persisted_position(stream_id)
348364
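To make the change above concrete, here is a minimal sketch (illustrative wiring, not the exact store code) of how a data store now constructs the generator, listing every table that shares the stream's ID sequence:

from synapse.storage.util.id_generators import MultiWriterIdGenerator

def build_account_data_id_gen(db_conn, database):
    # Sketch only: the real store supplies its own connection, database pool
    # and the writer list from the worker configuration.
    return MultiWriterIdGenerator(
        db_conn=db_conn,
        db=database,
        stream_name="account_data",
        instance_name="master",
        tables=[
            # (table, column holding the writer's instance name, stream ID column)
            ("room_account_data", "instance_name", "stream_id"),
            ("room_tags_revisions", "instance_name", "stream_id"),
            ("account_data", "instance_name", "stream_id"),
        ],
        sequence_name="account_data_sequence",
        writers=["master"],
    )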
1414 import abc
1515 import logging
1616 import threading
17 from typing import Callable, List, Optional
18
19 from synapse.storage.database import LoggingDatabaseConnection
17 from typing import TYPE_CHECKING, Callable, List, Optional
18
2019 from synapse.storage.engines import (
2120 BaseDatabaseEngine,
2221 IncorrectDatabaseSetup,
2423 )
2524 from synapse.storage.types import Connection, Cursor
2625
26 if TYPE_CHECKING:
27 from synapse.storage.database import LoggingDatabaseConnection
28
2729 logger = logging.getLogger(__name__)
2830
2931
4244 See docs/postgres.md for more information.
4345 """
4446
47 _INCONSISTENT_STREAM_ERROR = """
48 Postgres sequence '%(seq)s' is inconsistent with associated stream position
49 of '%(stream_name)s' in the 'stream_positions' table.
50
51 This is likely a programming error and should be reported at
52 https://github.com/matrix-org/synapse.
53
54 A temporary workaround to fix this error is to shut down Synapse (including
55 any and all workers) and run the following SQL:
56
57 DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
58
59 This will need to be done every time the server is restarted.
60 """
61
4562
4663 class SequenceGenerator(metaclass=abc.ABCMeta):
4764 """A class which generates a unique sequence of integers"""
5471 @abc.abstractmethod
5572 def check_consistency(
5673 self,
57 db_conn: LoggingDatabaseConnection,
74 db_conn: "LoggingDatabaseConnection",
5875 table: str,
5976 id_column: str,
77 stream_name: Optional[str] = None,
6078 positive: bool = True,
6179 ):
6280 """Should be called during start up to test that the current value of
6381 the sequence is greater than or equal to the maximum ID in the table.
6482
65 This is to handle various cases where the sequence value can get out
66 of sync with the table, e.g. if Synapse gets rolled back to a previous
83 This is to handle various cases where the sequence value can get out of
84 sync with the table, e.g. if Synapse gets rolled back to a previous
6785 version and then rolled forwards again.
86
87 If a stream name is given then this will check that any value in the
88 `stream_positions` table is less than or equal to the current sequence
89 value. If it isn't then it's likely that streams have been crossed
90 somewhere (e.g. two ID generators have the same stream name).
6891 """
6992 ...
7093
87110
88111 def check_consistency(
89112 self,
90 db_conn: LoggingDatabaseConnection,
113 db_conn: "LoggingDatabaseConnection",
91114 table: str,
92115 id_column: str,
116 stream_name: Optional[str] = None,
93117 positive: bool = True,
94118 ):
119 """See SequenceGenerator.check_consistency for docstring.
120 """
121
95122 txn = db_conn.cursor(txn_name="sequence.check_consistency")
96123
97124 # First we get the current max ID from the table.
115142 "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
116143 )
117144 last_value, is_called = txn.fetchone()
145
146 # If we have an associated stream check the stream_positions table.
147 max_in_stream_positions = None
148 if stream_name:
149 txn.execute(
150 "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
151 (stream_name,),
152 )
153 row = txn.fetchone()
154 if row:
155 max_in_stream_positions = row[0]
156
118157 txn.close()
119158
120159 # If `is_called` is False then `last_value` is actually the value that
133172 raise IncorrectDatabaseSetup(
134173 _INCONSISTENT_SEQUENCE_ERROR
135174 % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
175 )
176
177 # If we have values in the stream positions table then they have to be
178 # less than or equal to `last_value`
179 if max_in_stream_positions and max_in_stream_positions > last_value:
180 raise IncorrectDatabaseSetup(
181 _INCONSISTENT_STREAM_ERROR
182 % {"seq": self._sequence_name, "stream_name": stream_name}
136183 )
137184
138185
172219 return self._current_max_id
173220
174221 def check_consistency(
175 self, db_conn: Connection, table: str, id_column: str, positive: bool = True
222 self,
223 db_conn: Connection,
224 table: str,
225 id_column: str,
226 stream_name: Optional[str] = None,
227 positive: bool = True,
176228 ):
177229 # There is nothing to do for in memory sequences
178230 pass
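As a hedged sketch of the new stream-aware consistency check (the module path is assumed to be synapse.storage.util.sequence; this would be called at startup with a real database connection):

from synapse.storage.util.sequence import PostgresSequenceGenerator

def check_account_data_sequence(db_conn):
    seq = PostgresSequenceGenerator("account_data_sequence")
    # Raises IncorrectDatabaseSetup if the sequence lags behind either the
    # table's maximum stream_id or the stream_positions row for the stream.
    seq.check_consistency(
        db_conn,
        table="account_data",
        id_column="stream_id",
        stream_name="account_data",
    )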
3636 from unpaddedbase64 import decode_base64
3737
3838 from synapse.api.errors import Codes, SynapseError
39 from synapse.util.stringutils import parse_and_validate_server_name
3940
4041 if TYPE_CHECKING:
4142 from synapse.appservice.api import ApplicationService
256257
257258 @classmethod
258259 def is_valid(cls: Type[DS], s: str) -> bool:
260 """Parses the input string and attempts to ensure it is valid."""
259261 try:
260 cls.from_string(s)
262 obj = cls.from_string(s)
263 # Apply additional validation to the domain. This is only done
264 # during is_valid (and not part of from_string) since it is
265 # possible for invalid data to exist in room-state, etc.
266 parse_and_validate_server_name(obj.domain)
261267 return True
262268 except Exception:
263269 return False
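A hedged illustration of the effect of this change (the identifiers below are made up): is_valid now also rejects IDs whose domain part is not a plausible server name, while from_string stays permissive for data that may already exist in room state:

from synapse.types import UserID

assert UserID.is_valid("@alice:matrix.org")
# The space makes the domain fail parse_and_validate_server_name, so
# is_valid now returns False instead of accepting the ID.
assert not UserID.is_valid("@alice:not a hostname")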
104104 keylen=keylen,
105105 cache_name=name,
106106 cache_type=cache_type,
107 size_callback=(lambda d: len(d)) if iterable else None,
107 size_callback=(lambda d: len(d) or 1) if iterable else None,
108108 metrics_collection_callback=metrics_cb,
109109 apply_cache_factor_from_config=apply_cache_factor_from_config,
110110 ) # type: LruCache[KT, VT]
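A brief, hedged illustration of why the `or 1` matters: with the old callback an empty iterable reported a size of zero, so any number of cached empty results never counted towards the cache's size limit.

# Old vs new size callback behaviour (illustrative only).
old_size = lambda d: len(d)
new_size = lambda d: len(d) or 1

assert old_size([]) == 0       # empty entries used to be "free"
assert new_size([]) == 1       # now every entry costs at least one unit
assert new_size([1, 2, 3]) == 3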
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15 import heapq
1516 from itertools import islice
16 from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
17 from typing import (
18 Dict,
19 Generator,
20 Iterable,
21 Iterator,
22 Mapping,
23 Sequence,
24 Set,
25 Tuple,
26 TypeVar,
27 )
28
29 from synapse.types import Collection
1730
1831 T = TypeVar("T")
1932
4558 If the input is empty, no chunks are returned.
4659 """
4760 return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
61
62
63 def sorted_topologically(
64 nodes: Iterable[T], graph: Mapping[T, Collection[T]],
65 ) -> Generator[T, None, None]:
66 """Given a set of nodes and a graph, yield the nodes in toplogical order.
67
68 For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
69 """
70
71 # This is implemented by Kahn's algorithm.
72
73 degree_map = {node: 0 for node in nodes}
74 reverse_graph = {} # type: Dict[T, Set[T]]
75
76 for node, edges in graph.items():
77 if node not in degree_map:
78 continue
79
80 for edge in set(edges):
81 if edge in degree_map:
82 degree_map[node] += 1
83
84 reverse_graph.setdefault(edge, set()).add(node)
85 reverse_graph.setdefault(node, set())
86
87 zero_degree = [node for node, degree in degree_map.items() if degree == 0]
88 heapq.heapify(zero_degree)
89
90 while zero_degree:
91 node = heapq.heappop(zero_degree)
92 yield node
93
94 for edge in reverse_graph.get(node, []):
95 if edge in degree_map:
96 degree_map[edge] -= 1
97 if degree_map[edge] == 0:
98 heapq.heappush(zero_degree, edge)
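As a quick illustration of the helper above (values made up): dependencies listed in the graph are yielded before the nodes that depend on them, with ties broken by the smallest node.

from synapse.util.iterutils import sorted_topologically

# graph maps each node to the nodes it depends on.
assert list(sorted_topologically([1, 2], {1: [2]})) == [2, 1]
assert list(sorted_topologically([1, 2, 3], {2: [1], 3: [1, 2]})) == [1, 2, 3]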
107107 def __init__(self, clock, name):
108108 self.clock = clock
109109 self.name = name
110 parent_context = current_context()
110 curr_context = current_context()
111 if not curr_context:
112 logger.warning(
113 "Starting metrics collection %r from sentinel context: metrics will be lost",
114 name,
115 )
116 parent_context = None
117 else:
118 assert isinstance(curr_context, LoggingContext)
119 parent_context = curr_context
111120 self._logging_context = LoggingContext(
112121 "Measure[%s]" % (self.name,), parent_context
113122 )
1717 import re
1818 import string
1919 from collections.abc import Iterable
20 from typing import Optional, Tuple
2021
2122 from synapse.api.errors import Codes, SynapseError
2223
2425
2526 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
2627 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
28
29 # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
30 # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
31 # says "there is no grammar for media ids"
32 #
33 # The server_name part of this is purposely lax: use parse_and_validate_mxc_uri for
34 # additional validation.
35 #
36 MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
2737
2838 # random_string and random_string_with_symbols are used for a range of things,
2939 # some cryptographically important, some less so. We use SystemRandom to make sure
5868 )
5969
6070
71 def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
72 """Split a server name into host/port parts.
73
74 Args:
75 server_name: server name to parse
76
77 Returns:
78 host/port parts.
79
80 Raises:
81 ValueError if the server name could not be parsed.
82 """
83 try:
84 if server_name[-1] == "]":
85 # ipv6 literal, hopefully
86 return server_name, None
87
88 domain_port = server_name.rsplit(":", 1)
89 domain = domain_port[0]
90 port = int(domain_port[1]) if domain_port[1:] else None
91 return domain, port
92 except Exception:
93 raise ValueError("Invalid server name '%s'" % server_name)
94
95
96 VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
97
98
99 def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
100 """Split a server name into host/port parts and do some basic validation.
101
102 Args:
103 server_name: server name to parse
104
105 Returns:
106 host/port parts.
107
108 Raises:
109 ValueError if the server name could not be parsed.
110 """
111 host, port = parse_server_name(server_name)
112
113 # these tests don't need to be bulletproof as we'll find out soon enough
114 # if somebody is giving us invalid data. What we *do* need is to be sure
115 # that nobody is sneaking IP literals in that look like hostnames, etc.
116
117 # look for ipv6 literals
118 if host[0] == "[":
119 if host[-1] != "]":
120 raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
121 return host, port
122
123 # otherwise it should only be alphanumerics.
124 if not VALID_HOST_REGEX.match(host):
125 raise ValueError(
126 "Server name '%s' contains invalid characters" % (server_name,)
127 )
128
129 return host, port
130
131
132 def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
133 """Parse the given string as an MXC URI
134
135 Checks that the "server name" part is a valid server name
136
137 Args:
138 mxc: the (alleged) MXC URI to be checked
139 Returns:
140 hostname, port, media id
141 Raises:
142 ValueError if the URI cannot be parsed
143 """
144 m = MXC_REGEX.match(mxc)
145 if not m:
146 raise ValueError("mxc URI %r did not match expected format" % (mxc,))
147 server_name = m.group(1)
148 media_id = m.group(2)
149 host, port = parse_and_validate_server_name(server_name)
150 return host, port, media_id
151
152
61153 def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
62154 """If iterable has maxitems or fewer, return the stringification of a list
63155 containing those items.
74166 if len(items) <= maxitems:
75167 return str(items)
76168 return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
169
170
171 def strtobool(val: str) -> bool:
172 """Convert a string representation of truth to True or False
173
174 True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
175 are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
176 'val' is anything else.
177
178 This is lifted from distutils.util.strtobool, with the exception that it actually
179 returns a bool, rather than an int.
180 """
181 val = val.lower()
182 if val in ("y", "yes", "t", "true", "on", "1"):
183 return True
184 elif val in ("n", "no", "f", "false", "off", "0"):
185 return False
186 else:
187 raise ValueError("invalid truth value %r" % (val,))
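Illustrative usage of the helpers added to or moved into this module (the values below are made up, and assume the behaviour shown in the diff above):

from synapse.util.stringutils import (
    parse_and_validate_mxc_uri,
    parse_and_validate_server_name,
    parse_server_name,
    strtobool,
)

assert parse_server_name("matrix.org:8448") == ("matrix.org", 8448)
assert parse_and_validate_server_name("[::1]:8448") == ("[::1]", 8448)
assert parse_and_validate_mxc_uri("mxc://matrix.org/abcdef") == ("matrix.org", None, "abcdef")
assert strtobool("yes") is True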
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from synapse.config import ConfigError
16 from synapse.config._util import validate_config
17
18 from tests.unittest import TestCase
19
20
21 class ValidateConfigTestCase(TestCase):
22 """Test cases for synapse.config._util.validate_config"""
23
24 def test_bad_object_in_array(self):
25 """malformed objects within an array should be validated correctly"""
26
27 # consider a structure:
28 #
29 # array_of_objs:
30 # - r: 1
31 # foo: 2
32 #
33 # - r: 2
34 # bar: 3
35 #
36 # ... where each entry must contain an "r": check that the path
37 # to the required item is correctly reported.
38
39 schema = {
40 "type": "object",
41 "properties": {
42 "array_of_objs": {
43 "type": "array",
44 "items": {"type": "object", "required": ["r"]},
45 },
46 },
47 }
48
49 with self.assertRaises(ConfigError) as c:
50 validate_config(schema, {"array_of_objs": [{}]}, ("base",))
51
52 self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
3333
3434
3535 class PruneEventTestCase(unittest.TestCase):
36 """ Asserts that a new event constructed with `evdict` will look like
37 `matchdict` when it is redacted. """
38
3936 def run_test(self, evdict, matchdict, **kwargs):
40 self.assertEquals(
37 """
38 Asserts that a new event constructed with `evdict` will look like
39 `matchdict` when it is redacted.
40
41 Args:
42 evdict: The dictionary to build the event from.
43 matchdict: The expected resulting dictionary.
44 kwargs: Additional keyword arguments used to create the event.
45 """
46 self.assertEqual(
4147 prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
4248 )
4349
5460 )
5561
5662 def test_basic_keys(self):
57 self.run_test(
58 {
63 """Ensure that the keys that should be untouched are kept."""
64 # Note that some of the values below don't really make sense, but the
65 # pruning of events doesn't worry about the values of any fields (with
66 # the exception of the content field).
67 self.run_test(
68 {
69 "event_id": "$3:domain",
5970 "type": "A",
6071 "room_id": "!1:domain",
6172 "sender": "@2:domain",
73 "state_key": "B",
74 "content": {"other_key": "foo"},
75 "hashes": "hashes",
76 "signatures": {"domain": {"algo:1": "sigs"}},
77 "depth": 4,
78 "prev_events": "prev_events",
79 "prev_state": "prev_state",
80 "auth_events": "auth_events",
81 "origin": "domain",
82 "origin_server_ts": 1234,
83 "membership": "join",
84 # Also include a key that should be removed.
85 "other_key": "foo",
86 },
87 {
6288 "event_id": "$3:domain",
63 "origin": "domain",
64 },
65 {
6689 "type": "A",
6790 "room_id": "!1:domain",
6891 "sender": "@2:domain",
69 "event_id": "$3:domain",
92 "state_key": "B",
93 "hashes": "hashes",
94 "depth": 4,
95 "prev_events": "prev_events",
96 "prev_state": "prev_state",
97 "auth_events": "auth_events",
7098 "origin": "domain",
71 "content": {},
72 "signatures": {},
73 "unsigned": {},
74 },
75 )
76
77 def test_unsigned_age_ts(self):
78 self.run_test(
79 {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
99 "origin_server_ts": 1234,
100 "membership": "join",
101 "content": {},
102 "signatures": {"domain": {"algo:1": "sigs"}},
103 "unsigned": {},
104 },
105 )
106
107 # As of MSC2176 we now redact the membership and prev_state keys.
108 self.run_test(
109 {"type": "A", "prev_state": "prev_state", "membership": "join"},
110 {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
111 room_version=RoomVersions.MSC2176,
112 )
113
114 def test_unsigned(self):
115 """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
116 self.run_test(
80117 {
81118 "type": "B",
82119 "event_id": "$test:domain",
83 "content": {},
84 "signatures": {},
85 "unsigned": {"age_ts": 20},
86 },
87 )
88
89 self.run_test(
120 "unsigned": {
121 "age_ts": 20,
122 "replaces_state": "$test2:domain",
123 "other_key": "foo",
124 },
125 },
90126 {
91127 "type": "B",
92128 "event_id": "$test:domain",
93 "unsigned": {"other_key": "here"},
94 },
95 {
96 "type": "B",
97 "event_id": "$test:domain",
98 "content": {},
99 "signatures": {},
100 "unsigned": {},
129 "content": {},
130 "signatures": {},
131 "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
101132 },
102133 )
103134
104135 def test_content(self):
136 """The content dictionary should be stripped in most cases."""
105137 self.run_test(
106138 {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
107139 {
113145 },
114146 )
115147
148 # Some events keep a single content key/value.
149 EVENT_KEEP_CONTENT_KEYS = [
150 ("member", "membership", "join"),
151 ("join_rules", "join_rule", "invite"),
152 ("history_visibility", "history_visibility", "shared"),
153 ]
154 for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
155 self.run_test(
156 {
157 "type": "m.room." + event_type,
158 "event_id": "$test:domain",
159 "content": {key: value, "other_key": "foo"},
160 },
161 {
162 "type": "m.room." + event_type,
163 "event_id": "$test:domain",
164 "content": {key: value},
165 "signatures": {},
166 "unsigned": {},
167 },
168 )
169
170 def test_create(self):
171 """Create events are partially redacted until MSC2176."""
116172 self.run_test(
117173 {
118174 "type": "m.room.create",
119175 "event_id": "$test:domain",
120 "content": {"creator": "@2:domain", "other_field": "here"},
176 "content": {"creator": "@2:domain", "other_key": "foo"},
121177 },
122178 {
123179 "type": "m.room.create",
126182 "signatures": {},
127183 "unsigned": {},
128184 },
185 )
186
187 # After MSC2176, create events get nothing redacted.
188 self.run_test(
189 {"type": "m.room.create", "content": {"not_a_real_key": True}},
190 {
191 "type": "m.room.create",
192 "content": {"not_a_real_key": True},
193 "signatures": {},
194 "unsigned": {},
195 },
196 room_version=RoomVersions.MSC2176,
197 )
198
199 def test_power_levels(self):
200 """Power level events keep a variety of content keys."""
201 self.run_test(
202 {
203 "type": "m.room.power_levels",
204 "event_id": "$test:domain",
205 "content": {
206 "ban": 1,
207 "events": {"m.room.name": 100},
208 "events_default": 2,
209 "invite": 3,
210 "kick": 4,
211 "redact": 5,
212 "state_default": 6,
213 "users": {"@admin:domain": 100},
214 "users_default": 7,
215 "other_key": 8,
216 },
217 },
218 {
219 "type": "m.room.power_levels",
220 "event_id": "$test:domain",
221 "content": {
222 "ban": 1,
223 "events": {"m.room.name": 100},
224 "events_default": 2,
225 # Note that invite is not here.
226 "kick": 4,
227 "redact": 5,
228 "state_default": 6,
229 "users": {"@admin:domain": 100},
230 "users_default": 7,
231 },
232 "signatures": {},
233 "unsigned": {},
234 },
235 )
236
237 # After MSC2176, power levels events keep the invite key.
238 self.run_test(
239 {"type": "m.room.power_levels", "content": {"invite": 75}},
240 {
241 "type": "m.room.power_levels",
242 "content": {"invite": 75},
243 "signatures": {},
244 "unsigned": {},
245 },
246 room_version=RoomVersions.MSC2176,
129247 )
130248
131249 def test_alias_event(self):
145263 },
146264 )
147265
148 def test_msc2432_alias_event(self):
149 """After MSC2432, alias events have no special behavior."""
266 # After MSC2432, alias events have no special behavior.
150267 self.run_test(
151268 {"type": "m.room.aliases", "content": {"aliases": ["test"]}},
152269 {
156273 "unsigned": {},
157274 },
158275 room_version=RoomVersions.V6,
276 )
277
278 def test_redacts(self):
279 """Redaction events have no special behaviour until MSC2174/MSC2176."""
280
281 self.run_test(
282 {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
283 {
284 "type": "m.room.redaction",
285 "content": {},
286 "signatures": {},
287 "unsigned": {},
288 },
289 room_version=RoomVersions.V6,
290 )
291
292 # After MSC2174, redaction events keep the redacts content key.
293 self.run_test(
294 {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
295 {
296 "type": "m.room.redaction",
297 "content": {"redacts": "$test2:domain"},
298 "signatures": {},
299 "unsigned": {},
300 },
301 room_version=RoomVersions.MSC2176,
159302 )
160303
161304
117117
118118 def _mock_request():
119119 """Returns a mock which will stand in as a SynapseRequest"""
120 return Mock(spec=["getClientIP", "get_user_agent"])
120 return Mock(spec=["getClientIP", "getHeader"])
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import json
15 import re
16 from typing import Dict
17 from urllib.parse import parse_qs, urlencode, urlparse
15 from typing import Optional
16 from urllib.parse import parse_qs, urlparse
1817
1918 from mock import ANY, Mock, patch
2019
2120 import pymacaroons
2221
23 from twisted.web.resource import Resource
24
25 from synapse.api.errors import RedirectException
26 from synapse.handlers.oidc_handler import OidcError
2722 from synapse.handlers.sso import MappingException
28 from synapse.rest.client.v1 import login
29 from synapse.rest.synapse.client.pick_username import pick_username_resource
3023 from synapse.server import HomeServer
3124 from synapse.types import UserID
3225
3326 from tests.test_utils import FakeResponse, simple_async_mock
3427 from tests.unittest import HomeserverTestCase, override_config
28
29 try:
30 import authlib # noqa: F401
31
32 HAS_OIDC = True
33 except ImportError:
34 HAS_OIDC = False
35
3536
3637 # These are a few constants that are used as config parameters in the tests.
3738 ISSUER = "https://issuer/"
112113
113114
114115 class OidcHandlerTestCase(HomeserverTestCase):
116 if not HAS_OIDC:
117 skip = "requires OIDC"
118
115119 def default_config(self):
116120 config = super().default_config()
117121 config["public_baseurl"] = BASE_URL
140144 hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
141145
142146 self.handler = hs.get_oidc_handler()
147 self.provider = self.handler._providers["oidc"]
143148 sso_handler = hs.get_sso_handler()
144149 # Mock the render error method.
145150 self.render_error = Mock(return_value=None)
151156 return hs
152157
153158 def metadata_edit(self, values):
154 return patch.dict(self.handler._provider_metadata, values)
159 return patch.dict(self.provider._provider_metadata, values)
155160
156161 def assertRenderedError(self, error, error_description=None):
162 self.render_error.assert_called_once()
157163 args = self.render_error.call_args[0]
158164 self.assertEqual(args[1], error)
159165 if error_description is not None:
164170
165171 def test_config(self):
166172 """Basic config correctly sets up the callback URL and client auth correctly."""
167 self.assertEqual(self.handler._callback_url, CALLBACK_URL)
168 self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
169 self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
173 self.assertEqual(self.provider._callback_url, CALLBACK_URL)
174 self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
175 self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
170176
171177 @override_config({"oidc_config": {"discover": True}})
172178 def test_discovery(self):
173179 """The handler should discover the endpoints from OIDC discovery document."""
174180 # This would throw if some metadata were invalid
175 metadata = self.get_success(self.handler.load_metadata())
181 metadata = self.get_success(self.provider.load_metadata())
176182 self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
177183
178184 self.assertEqual(metadata.issuer, ISSUER)
184190
185191 # subsequent calls should be cached
186192 self.http_client.reset_mock()
187 self.get_success(self.handler.load_metadata())
193 self.get_success(self.provider.load_metadata())
188194 self.http_client.get_json.assert_not_called()
189195
190196 @override_config({"oidc_config": COMMON_CONFIG})
191197 def test_no_discovery(self):
192198 """When discovery is disabled, it should not try to load from discovery document."""
193 self.get_success(self.handler.load_metadata())
199 self.get_success(self.provider.load_metadata())
194200 self.http_client.get_json.assert_not_called()
195201
196202 @override_config({"oidc_config": COMMON_CONFIG})
197203 def test_load_jwks(self):
198204 """JWKS loading is done once (then cached) if used."""
199 jwks = self.get_success(self.handler.load_jwks())
205 jwks = self.get_success(self.provider.load_jwks())
200206 self.http_client.get_json.assert_called_once_with(JWKS_URI)
201207 self.assertEqual(jwks, {"keys": []})
202208
203209 # subsequent calls should be cached…
204210 self.http_client.reset_mock()
205 self.get_success(self.handler.load_jwks())
211 self.get_success(self.provider.load_jwks())
206212 self.http_client.get_json.assert_not_called()
207213
208214 # …unless forced
209215 self.http_client.reset_mock()
210 self.get_success(self.handler.load_jwks(force=True))
216 self.get_success(self.provider.load_jwks(force=True))
211217 self.http_client.get_json.assert_called_once_with(JWKS_URI)
212218
213219 # Throw if the JWKS uri is missing
214220 with self.metadata_edit({"jwks_uri": None}):
215 self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
221 self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
216222
217223 # Return empty key set if JWKS are not used
218 self.handler._scopes = [] # not asking the openid scope
224 self.provider._scopes = [] # not asking the openid scope
219225 self.http_client.get_json.reset_mock()
220 jwks = self.get_success(self.handler.load_jwks(force=True))
226 jwks = self.get_success(self.provider.load_jwks(force=True))
221227 self.http_client.get_json.assert_not_called()
222228 self.assertEqual(jwks, {"keys": []})
223229
224230 @override_config({"oidc_config": COMMON_CONFIG})
225231 def test_validate_config(self):
226232 """Provider metadatas are extensively validated."""
227 h = self.handler
233 h = self.provider
228234
229235 # Default test config does not throw
230236 h._validate_metadata()
303309 """Provider metadata validation can be disabled by config."""
304310 with self.metadata_edit({"issuer": "http://insecure"}):
305311 # This should not throw
306 self.handler._validate_metadata()
312 self.provider._validate_metadata()
307313
308314 def test_redirect_request(self):
309315 """The redirect request has the right arguments & generates a valid session cookie."""
310316 req = Mock(spec=["addCookie"])
311317 url = self.get_success(
312 self.handler.handle_redirect_request(req, b"http://client/redirect")
318 self.provider.handle_redirect_request(req, b"http://client/redirect")
313319 )
314320 url = urlparse(url)
315321 auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
338344 cookie = args[1]
339345
340346 macaroon = pymacaroons.Macaroon.deserialize(cookie)
341 state = self.handler._get_value_from_macaroon(macaroon, "state")
342 nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
343 redirect = self.handler._get_value_from_macaroon(
347 state = self.handler._token_generator._get_value_from_macaroon(
348 macaroon, "state"
349 )
350 nonce = self.handler._token_generator._get_value_from_macaroon(
351 macaroon, "nonce"
352 )
353 redirect = self.handler._token_generator._get_value_from_macaroon(
344354 macaroon, "client_redirect_url"
345355 )
346356
373383
374384 # ensure that we are correctly testing the fallback when "get_extra_attributes"
375385 # is not implemented.
376 mapping_provider = self.handler._user_mapping_provider
386 mapping_provider = self.provider._user_mapping_provider
377387 with self.assertRaises(AttributeError):
378388 _ = mapping_provider.get_extra_attributes
379389
388398 "username": username,
389399 }
390400 expected_user_id = "@%s:%s" % (username, self.hs.hostname)
391 self.handler._exchange_code = simple_async_mock(return_value=token)
392 self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
393 self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
401 self.provider._exchange_code = simple_async_mock(return_value=token)
402 self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
403 self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
394404 auth_handler = self.hs.get_auth_handler()
395405 auth_handler.complete_sso_login = simple_async_mock()
396406
400410 client_redirect_url = "http://client/redirect"
401411 user_agent = "Browser"
402412 ip_address = "10.0.0.1"
403 session = self.handler._generate_oidc_session_token(
404 state=state,
405 nonce=nonce,
406 client_redirect_url=client_redirect_url,
407 ui_auth_session_id=None,
408 )
413 session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
409414 request = _build_callback_request(
410415 code, state, session, user_agent=user_agent, ip_address=ip_address
411416 )
415420 auth_handler.complete_sso_login.assert_called_once_with(
416421 expected_user_id, request, client_redirect_url, None,
417422 )
418 self.handler._exchange_code.assert_called_once_with(code)
419 self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
420 self.handler._fetch_userinfo.assert_not_called()
423 self.provider._exchange_code.assert_called_once_with(code)
424 self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
425 self.provider._fetch_userinfo.assert_not_called()
421426 self.render_error.assert_not_called()
422427
423428 # Handle mapping errors
424429 with patch.object(
425 self.handler,
430 self.provider,
426431 "_remote_id_from_userinfo",
427432 new=Mock(side_effect=MappingException()),
428433 ):
430435 self.assertRenderedError("mapping_error")
431436
432437 # Handle ID token errors
433 self.handler._parse_id_token = simple_async_mock(raises=Exception())
438 self.provider._parse_id_token = simple_async_mock(raises=Exception())
434439 self.get_success(self.handler.handle_oidc_callback(request))
435440 self.assertRenderedError("invalid_token")
436441
437442 auth_handler.complete_sso_login.reset_mock()
438 self.handler._exchange_code.reset_mock()
439 self.handler._parse_id_token.reset_mock()
440 self.handler._fetch_userinfo.reset_mock()
443 self.provider._exchange_code.reset_mock()
444 self.provider._parse_id_token.reset_mock()
445 self.provider._fetch_userinfo.reset_mock()
441446
442447 # With userinfo fetching
443 self.handler._scopes = [] # do not ask the "openid" scope
448 self.provider._scopes = [] # do not ask the "openid" scope
444449 self.get_success(self.handler.handle_oidc_callback(request))
445450
446451 auth_handler.complete_sso_login.assert_called_once_with(
447452 expected_user_id, request, client_redirect_url, None,
448453 )
449 self.handler._exchange_code.assert_called_once_with(code)
450 self.handler._parse_id_token.assert_not_called()
451 self.handler._fetch_userinfo.assert_called_once_with(token)
454 self.provider._exchange_code.assert_called_once_with(code)
455 self.provider._parse_id_token.assert_not_called()
456 self.provider._fetch_userinfo.assert_called_once_with(token)
452457 self.render_error.assert_not_called()
453458
454459 # Handle userinfo fetching error
455 self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
460 self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
456461 self.get_success(self.handler.handle_oidc_callback(request))
457462 self.assertRenderedError("fetch_error")
458463
459464 # Handle code exchange failure
460 self.handler._exchange_code = simple_async_mock(
465 from synapse.handlers.oidc_handler import OidcError
466
467 self.provider._exchange_code = simple_async_mock(
461468 raises=OidcError("invalid_request")
462469 )
463470 self.get_success(self.handler.handle_oidc_callback(request))
487494 self.assertRenderedError("invalid_session")
488495
489496 # Mismatching session
490 session = self.handler._generate_oidc_session_token(
491 state="state",
492 nonce="nonce",
493 client_redirect_url="http://client/redirect",
494 ui_auth_session_id=None,
497 session = self._generate_oidc_session_token(
498 state="state", nonce="nonce", client_redirect_url="http://client/redirect",
495499 )
496500 request.args = {}
497501 request.args[b"state"] = [b"mismatching state"]
515519 return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
516520 )
517521 code = "code"
518 ret = self.get_success(self.handler._exchange_code(code))
522 ret = self.get_success(self.provider._exchange_code(code))
519523 kwargs = self.http_client.request.call_args[1]
520524
521525 self.assertEqual(ret, token)
537541 body=b'{"error": "foo", "error_description": "bar"}',
538542 )
539543 )
540 exc = self.get_failure(self.handler._exchange_code(code), OidcError)
544 from synapse.handlers.oidc_handler import OidcError
545
546 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
541547 self.assertEqual(exc.value.error, "foo")
542548 self.assertEqual(exc.value.error_description, "bar")
543549
547553 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
548554 )
549555 )
550 exc = self.get_failure(self.handler._exchange_code(code), OidcError)
556 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
551557 self.assertEqual(exc.value.error, "server_error")
552558
553559 # Internal server error with JSON body
559565 )
560566 )
561567
562 exc = self.get_failure(self.handler._exchange_code(code), OidcError)
568 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
563569 self.assertEqual(exc.value.error, "internal_server_error")
564570
565571 # 4xx error without "error" field
566572 self.http_client.request = simple_async_mock(
567573 return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
568574 )
569 exc = self.get_failure(self.handler._exchange_code(code), OidcError)
575 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
570576 self.assertEqual(exc.value.error, "server_error")
571577
572578 # 2xx error with "error" field
575581 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
576582 )
577583 )
578 exc = self.get_failure(self.handler._exchange_code(code), OidcError)
584 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
579585 self.assertEqual(exc.value.error, "some_error")
580586
581587 @override_config(
601607 "username": "foo",
602608 "phone": "1234567",
603609 }
604 self.handler._exchange_code = simple_async_mock(return_value=token)
605 self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
610 self.provider._exchange_code = simple_async_mock(return_value=token)
611 self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
606612 auth_handler = self.hs.get_auth_handler()
607613 auth_handler.complete_sso_login = simple_async_mock()
608614
609615 state = "state"
610616 client_redirect_url = "http://client/redirect"
611 session = self.handler._generate_oidc_session_token(
612 state=state,
613 nonce="nonce",
614 client_redirect_url=client_redirect_url,
615 ui_auth_session_id=None,
617 session = self._generate_oidc_session_token(
618 state=state, nonce="nonce", client_redirect_url=client_redirect_url,
616619 )
617620 request = _build_callback_request("code", state, session)
618621
826829 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
827830 self.assertRenderedError("mapping_error", "localpart is invalid: ")
828831
829
830 class UsernamePickerTestCase(HomeserverTestCase):
831 servlets = [login.register_servlets]
832
833 def default_config(self):
834 config = super().default_config()
835 config["public_baseurl"] = BASE_URL
836 oidc_config = {
837 "enabled": True,
838 "client_id": CLIENT_ID,
839 "client_secret": CLIENT_SECRET,
840 "issuer": ISSUER,
841 "scopes": SCOPES,
842 "user_mapping_provider": {
843 "config": {"display_name_template": "{{ user.displayname }}"}
844 },
845 }
846
847 # Update this config with what's in the default config so that
848 # override_config works as expected.
849 oidc_config.update(config.get("oidc_config", {}))
850 config["oidc_config"] = oidc_config
851
852 # whitelist this client URI so we redirect straight to it rather than
853 # serving a confirmation page
854 config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
855 return config
856
857 def create_resource_dict(self) -> Dict[str, Resource]:
858 d = super().create_resource_dict()
859 d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
860 return d
861
862 def test_username_picker(self):
863 """Test the happy path of a username picker flow."""
864 client_redirect_url = "https://whitelisted.client"
865
866 # first of all, mock up an OIDC callback to the OidcHandler, which should
867 # raise a RedirectException
868 userinfo = {"sub": "tester", "displayname": "Jonny"}
869 f = self.get_failure(
870 _make_callback_with_userinfo(
871 self.hs, userinfo, client_redirect_url=client_redirect_url
832 def _generate_oidc_session_token(
833 self,
834 state: str,
835 nonce: str,
836 client_redirect_url: str,
837 ui_auth_session_id: Optional[str] = None,
838 ) -> str:
839 from synapse.handlers.oidc_handler import OidcSessionData
840
841 return self.handler._token_generator.generate_oidc_session_token(
842 state=state,
843 session_data=OidcSessionData(
844 idp_id="oidc",
845 nonce=nonce,
846 client_redirect_url=client_redirect_url,
847 ui_auth_session_id=ui_auth_session_id,
872848 ),
873 RedirectException,
874 )
875
876 # check the Location and cookies returned by the RedirectException
877 self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
878 cookieheader = f.value.cookies[0]
879 regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
880 m = regex.search(cookieheader)
881 if not m:
882 self.fail("cookie header %s does not match %s" % (cookieheader, regex))
883
884 # introspect the sso handler a bit to check that the username mapping session
885 # looks ok.
886 session_id = m.group(1).decode("ascii")
887 username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
888 self.assertIn(
889 session_id, username_mapping_sessions, "session id not found in map"
890 )
891 session = username_mapping_sessions[session_id]
892 self.assertEqual(session.remote_user_id, "tester")
893 self.assertEqual(session.display_name, "Jonny")
894 self.assertEqual(session.client_redirect_url, client_redirect_url)
895
896 # the expiry time should be about 15 minutes away
897 expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
898 self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
899
900 # Now, submit a username to the username picker, which should serve a redirect
901 # back to the client
902 submit_path = f.value.location + b"/submit"
903 content = urlencode({b"username": b"bobby"}).encode("utf8")
904 chan = self.make_request(
905 "POST",
906 path=submit_path,
907 content=content,
908 content_is_form=True,
909 custom_headers=[
910 ("Cookie", cookieheader),
911 # old versions of twisted don't do form-parsing without a valid
912 # content-length header.
913 ("Content-Length", str(len(content))),
914 ],
915 )
916 self.assertEqual(chan.code, 302, chan.result)
917 location_headers = chan.headers.getRawHeaders("Location")
918 # ensure that the returned location starts with the requested redirect URL
919 self.assertEqual(
920 location_headers[0][: len(client_redirect_url)], client_redirect_url
921 )
922
923 # fish the login token out of the returned redirect uri
924 parts = urlparse(location_headers[0])
925 query = parse_qs(parts.query)
926 login_token = query["loginToken"][0]
927
928 # finally, submit the matrix login token to the login API, which gives us our
929 # matrix access token, mxid, and device id.
930 chan = self.make_request(
931 "POST", "/login", content={"type": "m.login.token", "token": login_token},
932 )
933 self.assertEqual(chan.code, 200, chan.result)
934 self.assertEqual(chan.json_body["user_id"], "@bobby:test")
849 )
935850
936851
937852 async def _make_callback_with_userinfo(
947862 userinfo: the OIDC userinfo dict
948863 client_redirect_url: the URL to redirect to on success.
949864 """
865 from synapse.handlers.oidc_handler import OidcSessionData
866
950867 handler = hs.get_oidc_handler()
951 handler._exchange_code = simple_async_mock(return_value={})
952 handler._parse_id_token = simple_async_mock(return_value=userinfo)
953 handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
868 provider = handler._providers["oidc"]
869 provider._exchange_code = simple_async_mock(return_value={})
870 provider._parse_id_token = simple_async_mock(return_value=userinfo)
871 provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
954872
955873 state = "state"
956 session = handler._generate_oidc_session_token(
874 session = handler._token_generator.generate_oidc_session_token(
957875 state=state,
958 nonce="nonce",
959 client_redirect_url=client_redirect_url,
960 ui_auth_session_id=None,
876 session_data=OidcSessionData(
877 idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
878 ),
961879 )
962880 request = _build_callback_request("code", state, session)
963881
993911 "addCookie",
994912 "requestHeaders",
995913 "getClientIP",
996 "get_user_agent",
914 "getHeader",
997915 ]
998916 )
999917
1002920 request.args[b"code"] = [code.encode("utf-8")]
1003921 request.args[b"state"] = [state.encode("utf-8")]
1004922 request.getClientIP.return_value = ip_address
1005 request.get_user_agent.return_value = user_agent
1006923 return request
104104 "Frank",
105105 )
106106
107 # Set displayname to an empty string
108 yield defer.ensureDeferred(
109 self.handler.set_displayname(
110 self.frank, synapse.types.create_requester(self.frank), ""
111 )
112 )
113
114 self.assertIsNone(
115 (
116 yield defer.ensureDeferred(
117 self.store.get_profile_displayname(self.frank.localpart)
118 )
119 )
120 )
121
107122 @defer.inlineCallbacks
108123 def test_set_my_name_if_disabled(self):
109124 self.hs.config.enable_set_displayname = False
222237 "http://my.server/me.png",
223238 )
224239
240 # Set avatar to an empty string
241 yield defer.ensureDeferred(
242 self.handler.set_avatar_url(
243 self.frank, synapse.types.create_requester(self.frank), "",
244 )
245 )
246
247 self.assertIsNone(
248 (
249 yield defer.ensureDeferred(
250 self.store.get_profile_avatar_url(self.frank.localpart)
251 )
252 ),
253 )
254
225255 @defer.inlineCallbacks
226256 def test_set_my_avatar_if_disabled(self):
227257 self.hs.config.enable_set_avatar_url = False
261261
262262 def _mock_request():
263263 """Returns a mock which will stand in as a SynapseRequest"""
264 return Mock(spec=["getClientIP", "get_user_agent"])
264 return Mock(spec=["getClientIP", "getHeader"])
10941094 # Expire both caches and repeat the request
10951095 self.reactor.pump((10000.0,))
10961096
1097 # Repated the request, this time it should fail if the lookup fails.
1097 # Repeat the request, this time it should fail if the lookup fails.
10981098 fetch_d = defer.ensureDeferred(
10991099 self.well_known_resolver.get_well_known(b"testserv")
11001100 )
11291129 content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
11301130 )
11311131
1132 # The result is sucessful, but disabled delegation.
1132 # The result is successful, but disabled delegation.
11331133 r = self.successResultOf(fetch_d)
11341134 self.assertIsNone(r.delegated_server)
11351135
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 from io import BytesIO
15
16 from mock import Mock
17
18 from twisted.python.failure import Failure
19 from twisted.web.client import ResponseDone
20
21 from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
22
23 from tests.unittest import TestCase
24
25
26 class ReadBodyWithMaxSizeTests(TestCase):
27 def setUp(self):
28 """Start reading the body, returns the response, result and proto"""
29 self.response = Mock()
30 self.result = BytesIO()
31 self.deferred = read_body_with_max_size(self.response, self.result, 6)
32
33 # Fish the protocol out of the response.
34 self.protocol = self.response.deliverBody.call_args[0][0]
35 self.protocol.transport = Mock()
36
37 def _cleanup_error(self):
38 """Ensure that the error in the Deferred is handled gracefully."""
39 called = [False]
40
41 def errback(f):
42 called[0] = True
43
44 self.deferred.addErrback(errback)
45 self.assertTrue(called[0])
46
47 def test_no_error(self):
48 """A response that is NOT too large."""
49
50 # Start sending data.
51 self.protocol.dataReceived(b"12345")
52 # Close the connection.
53 self.protocol.connectionLost(Failure(ResponseDone()))
54
55 self.assertEqual(self.result.getvalue(), b"12345")
56 self.assertEqual(self.deferred.result, 5)
57
58 def test_too_large(self):
59 """A response which is too large raises an exception."""
60
61 # Start sending data.
62 self.protocol.dataReceived(b"1234567890")
63 # Close the connection.
64 self.protocol.connectionLost(Failure(ResponseDone()))
65
66 self.assertEqual(self.result.getvalue(), b"1234567890")
67 self.assertIsInstance(self.deferred.result, Failure)
68 self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
69 self._cleanup_error()
70
71 def test_multiple_packets(self):
72 """Data should be accummulated through mutliple packets."""
73
74 # Start sending data.
75 self.protocol.dataReceived(b"12")
76 self.protocol.dataReceived(b"34")
77 # Close the connection.
78 self.protocol.connectionLost(Failure(ResponseDone()))
79
80 self.assertEqual(self.result.getvalue(), b"1234")
81 self.assertEqual(self.deferred.result, 4)
82
83 def test_additional_data(self):
84 """A connection can receive data after being closed."""
85
86 # Start sending data.
87 self.protocol.dataReceived(b"1234567890")
88 self.assertIsInstance(self.deferred.result, Failure)
89 self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
90 self.protocol.transport.loseConnection.assert_called_once()
91
92 # More data might have come in.
93 self.protocol.dataReceived(b"1234567890")
94 # Close the connection.
95 self.protocol.connectionLost(Failure(ResponseDone()))
96
97 self.assertEqual(self.result.getvalue(), b"1234567890")
98 self.assertIsInstance(self.deferred.result, Failure)
99 self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
100 self._cleanup_error()
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
14 from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
1515
1616 from tests import unittest
1717
559559 self.pump()
560560
561561 f = self.failureResultOf(test_d)
562 self.assertIsInstance(f.value, ValueError)
562 self.assertIsInstance(f.value, RequestSendFailed)
5757 ]
5858
5959 def prepare(self, reactor, clock, hs):
60 self.store = hs.get_datastore()
61
6260 self.admin_user = self.register_user("admin", "pass", admin=True)
6361 self.admin_user_tok = self.login("admin", "pass")
6462
154152 ]
155153
156154 def prepare(self, reactor, clock, hs):
157 self.store = hs.get_datastore()
158 self.hs = hs
159
160155 # Allow for uploading and downloading to/from the media repo
161156 self.media_repo = hs.get_media_repository_resource()
162157 self.download_resource = self.media_repo.children[b"download"]
430425
431426 # Mark the second item as safe from quarantine.
432427 _, media_id_2 = server_and_media_id_2.split("/")
433 self.get_success(self.store.mark_local_media_as_safe(media_id_2))
428 # Quarantine the media
429 url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
430 channel = self.make_request("POST", url, access_token=admin_user_tok)
431 self.pump(1.0)
432 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
434433
435434 # Quarantine all media by this user
436435 url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
3131 ]
3232
3333 def prepare(self, reactor, clock, hs):
34 self.store = hs.get_datastore()
35
3634 self.admin_user = self.register_user("admin", "pass", admin=True)
3735 self.admin_user_tok = self.login("admin", "pass")
3836
370368 ]
371369
372370 def prepare(self, reactor, clock, hs):
373 self.store = hs.get_datastore()
374
375371 self.admin_user = self.register_user("admin", "pass", admin=True)
376372 self.admin_user_tok = self.login("admin", "pass")
377373
3434 ]
3535
3636 def prepare(self, reactor, clock, hs):
37 self.handler = hs.get_device_handler()
3837 self.media_repo = hs.get_media_repository_resource()
3938 self.server_name = hs.hostname
4039
180179 ]
181180
182181 def prepare(self, reactor, clock, hs):
183 self.handler = hs.get_device_handler()
184182 self.media_repo = hs.get_media_repository_resource()
185183 self.server_name = hs.hostname
186184
604604 ]
605605
606606 def prepare(self, reactor, clock, hs):
607 self.store = hs.get_datastore()
608
609607 # Create user
610608 self.admin_user = self.register_user("admin", "pass", admin=True)
611609 self.admin_user_tok = self.login("admin", "pass")
3030 ]
3131
3232 def prepare(self, reactor, clock, hs):
33 self.store = hs.get_datastore()
3433 self.media_repo = hs.get_media_repository_resource()
3534
3635 self.admin_user = self.register_user("admin", "pass", admin=True)
2424 import synapse.rest.admin
2525 from synapse.api.constants import UserTypes
2626 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
27 from synapse.api.room_versions import RoomVersions
2728 from synapse.rest.client.v1 import login, logout, profile, room
2829 from synapse.rest.client.v2_alpha import devices, sync
2930
586587 _search_test(None, "bar", "user_id")
587588
588589
590 class DeactivateAccountTestCase(unittest.HomeserverTestCase):
591
592 servlets = [
593 synapse.rest.admin.register_servlets,
594 login.register_servlets,
595 ]
596
597 def prepare(self, reactor, clock, hs):
598 self.store = hs.get_datastore()
599
600 self.admin_user = self.register_user("admin", "pass", admin=True)
601 self.admin_user_tok = self.login("admin", "pass")
602
603 self.other_user = self.register_user("user", "pass", displayname="User1")
604 self.other_user_token = self.login("user", "pass")
605 self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
606 self.other_user
607 )
608 self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
609 self.other_user
610 )
611
612 # set attributes for user
613 self.get_success(
614 self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
615 )
616 self.get_success(
617 self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
618 )
619
620 def test_no_auth(self):
621 """
622 Try to deactivate users without authentication.
623 """
624 channel = self.make_request("POST", self.url, b"{}")
625
626 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
627 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
628
629 def test_requester_is_not_admin(self):
630 """
631 If the user is not a server admin, an error is returned.
632 """
633 url = "/_synapse/admin/v1/deactivate/@bob:test"
634
635 channel = self.make_request("POST", url, access_token=self.other_user_token)
636
637 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
638 self.assertEqual("You are not a server admin", channel.json_body["error"])
639
640 channel = self.make_request(
641 "POST", url, access_token=self.other_user_token, content=b"{}",
642 )
643
644 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
645 self.assertEqual("You are not a server admin", channel.json_body["error"])
646
647 def test_user_does_not_exist(self):
648 """
649 Tests that deactivation for a user that does not exist returns a 404
650 """
651
652 channel = self.make_request(
653 "POST",
654 "/_synapse/admin/v1/deactivate/@unknown_person:test",
655 access_token=self.admin_user_tok,
656 )
657
658 self.assertEqual(404, channel.code, msg=channel.json_body)
659 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
660
661 def test_erase_is_not_bool(self):
662 """
663 If the parameter `erase` is not a boolean, an error is returned
664 """
665 body = json.dumps({"erase": "False"})
666
667 channel = self.make_request(
668 "POST",
669 self.url,
670 content=body.encode(encoding="utf_8"),
671 access_token=self.admin_user_tok,
672 )
673
674 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
675 self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
676
677 def test_user_is_not_local(self):
678 """
679 Tests that deactivation for a user that is not local returns a 400
680 """
681 url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
682
683 channel = self.make_request("POST", url, access_token=self.admin_user_tok)
684
685 self.assertEqual(400, channel.code, msg=channel.json_body)
686 self.assertEqual("Can only deactivate local users", channel.json_body["error"])
687
688 def test_deactivate_user_erase_true(self):
689 """
690 Test deactivating a user with `erase` set to `true`
691 """
692
693 # Get user
694 channel = self.make_request(
695 "GET", self.url_other_user, access_token=self.admin_user_tok,
696 )
697
698 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
699 self.assertEqual("@user:test", channel.json_body["name"])
700 self.assertEqual(False, channel.json_body["deactivated"])
701 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
702 self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
703 self.assertEqual("User1", channel.json_body["displayname"])
704
705 # Deactivate user
706 body = json.dumps({"erase": True})
707
708 channel = self.make_request(
709 "POST",
710 self.url,
711 access_token=self.admin_user_tok,
712 content=body.encode(encoding="utf_8"),
713 )
714
715 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
716
717 # Get user
718 channel = self.make_request(
719 "GET", self.url_other_user, access_token=self.admin_user_tok,
720 )
721
722 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
723 self.assertEqual("@user:test", channel.json_body["name"])
724 self.assertEqual(True, channel.json_body["deactivated"])
725 self.assertEqual(0, len(channel.json_body["threepids"]))
726 self.assertIsNone(channel.json_body["avatar_url"])
727 self.assertIsNone(channel.json_body["displayname"])
728
729 self._is_erased("@user:test", True)
730
731 def test_deactivate_user_erase_false(self):
732 """
733 Test deactivating a user with `erase` set to `false`
734 """
735
736 # Get user
737 channel = self.make_request(
738 "GET", self.url_other_user, access_token=self.admin_user_tok,
739 )
740
741 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
742 self.assertEqual("@user:test", channel.json_body["name"])
743 self.assertEqual(False, channel.json_body["deactivated"])
744 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
745 self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
746 self.assertEqual("User1", channel.json_body["displayname"])
747
748 # Deactivate user
749 body = json.dumps({"erase": False})
750
751 channel = self.make_request(
752 "POST",
753 self.url,
754 access_token=self.admin_user_tok,
755 content=body.encode(encoding="utf_8"),
756 )
757
758 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
759
760 # Get user
761 channel = self.make_request(
762 "GET", self.url_other_user, access_token=self.admin_user_tok,
763 )
764
765 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
766 self.assertEqual("@user:test", channel.json_body["name"])
767 self.assertEqual(True, channel.json_body["deactivated"])
768 self.assertEqual(0, len(channel.json_body["threepids"]))
769 self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
770 self.assertEqual("User1", channel.json_body["displayname"])
771
772 self._is_erased("@user:test", False)
773
774 def _is_erased(self, user_id: str, expect: bool) -> None:
775 """Assert that the user is erased or not
776 """
777 d = self.store.is_user_erased(user_id)
778 if expect:
779 self.assertTrue(self.get_success(d))
780 else:
781 self.assertFalse(self.get_success(d))
782
783
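The test case above exercises the admin account-deactivation endpoint: a POST to /_synapse/admin/v1/deactivate/<user_id>, authenticated with an admin access token, with an optional boolean "erase" field in the JSON body. A hedged sketch of calling it outside the test harness; the base URL, token and the use of the requests library are illustrative assumptions, not anything the tests mandate:

import json
import urllib.parse

import requests  # any HTTP client would do; requests is assumed for brevity


def deactivate_user(base_url: str, admin_token: str, user_id: str, erase: bool) -> None:
    """POST to the deactivate admin endpoint, as the tests above do via make_request."""
    url = "%s/_synapse/admin/v1/deactivate/%s" % (
        base_url.rstrip("/"),
        urllib.parse.quote(user_id),
    )
    resp = requests.post(
        url,
        headers={"Authorization": "Bearer " + admin_token},
        data=json.dumps({"erase": erase}),
    )
    resp.raise_for_status()

# e.g. deactivate_user("https://synapse.example.com", "<admin token>", "@user:test", True)
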
589784 class UserRestTestCase(unittest.HomeserverTestCase):
590785
591786 servlets = [
9851180 Test deactivating another user.
9861181 """
9871182
1183 # set attributes for user
1184 self.get_success(
1185 self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
1186 )
1187 self.get_success(
1188 self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
1189 )
1190
1191 # Get user
1192 channel = self.make_request(
1193 "GET", self.url_other_user, access_token=self.admin_user_tok,
1194 )
1195
1196 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
1197 self.assertEqual("@user:test", channel.json_body["name"])
1198 self.assertEqual(False, channel.json_body["deactivated"])
1199 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
1200 self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
1201 self.assertEqual("User", channel.json_body["displayname"])
1202
9881203 # Deactivate user
9891204 body = json.dumps({"deactivated": True})
9901205
9981213 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
9991214 self.assertEqual("@user:test", channel.json_body["name"])
10001215 self.assertEqual(True, channel.json_body["deactivated"])
1216 self.assertEqual(0, len(channel.json_body["threepids"]))
1217 self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
1218 self.assertEqual("User", channel.json_body["displayname"])
10011219 # the user is deactivated, the threepid will be deleted
10021220
10031221 # Get user
10081226 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
10091227 self.assertEqual("@user:test", channel.json_body["name"])
10101228 self.assertEqual(True, channel.json_body["deactivated"])
1229 self.assertEqual(0, len(channel.json_body["threepids"]))
1230 self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
1231 self.assertEqual("User", channel.json_body["displayname"])
10111232
10121233 @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
10131234 def test_change_name_deactivate_user_user_directory(self):
12031424 ]
12041425
12051426 def prepare(self, reactor, clock, hs):
1206 self.store = hs.get_datastore()
1207
12081427 self.admin_user = self.register_user("admin", "pass", admin=True)
12091428 self.admin_user_tok = self.login("admin", "pass")
12101429
12351454
12361455 def test_user_does_not_exist(self):
12371456 """
1238 Tests that a lookup for a user that does not exist returns a 404
1457 Tests that a lookup for a user that does not exist returns an empty list
12391458 """
12401459 url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
12411460 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
12421461
1243 self.assertEqual(404, channel.code, msg=channel.json_body)
1244 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
1462 self.assertEqual(200, channel.code, msg=channel.json_body)
1463 self.assertEqual(0, channel.json_body["total"])
1464 self.assertEqual(0, len(channel.json_body["joined_rooms"]))
12451465
12461466 def test_user_is_not_local(self):
12471467 """
1248 Tests that a lookup for a user that is not a local returns a 400
1468 Tests that a lookup for a user that is not local and has joined no rooms returns an empty list
12491469 """
12501470 url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
12511471
12521472 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
12531473
1254 self.assertEqual(400, channel.code, msg=channel.json_body)
1255 self.assertEqual("Can only lookup local users", channel.json_body["error"])
1474 self.assertEqual(200, channel.code, msg=channel.json_body)
1475 self.assertEqual(0, channel.json_body["total"])
1476 self.assertEqual(0, len(channel.json_body["joined_rooms"]))
12561477
12571478 def test_no_memberships(self):
12581479 """
12821503 self.assertEqual(200, channel.code, msg=channel.json_body)
12831504 self.assertEqual(number_rooms, channel.json_body["total"])
12841505 self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
1506
1507 def test_get_rooms_with_nonlocal_user(self):
1508 """
1509 Tests that a normal lookup for rooms is successful with a non-local user
1510 """
1511
1512 other_user_tok = self.login("user", "pass")
1513 event_builder_factory = self.hs.get_event_builder_factory()
1514 event_creation_handler = self.hs.get_event_creation_handler()
1515 storage = self.hs.get_storage()
1516
1517 # Create two rooms, one with a local user only and one with both a local
1518 # and remote user.
1519 self.helper.create_room_as(self.other_user, tok=other_user_tok)
1520 local_and_remote_room_id = self.helper.create_room_as(
1521 self.other_user, tok=other_user_tok
1522 )
1523
1524 # Add a remote user to the room.
1525 builder = event_builder_factory.for_room_version(
1526 RoomVersions.V1,
1527 {
1528 "type": "m.room.member",
1529 "sender": "@joiner:remote_hs",
1530 "state_key": "@joiner:remote_hs",
1531 "room_id": local_and_remote_room_id,
1532 "content": {"membership": "join"},
1533 },
1534 )
1535
1536 event, context = self.get_success(
1537 event_creation_handler.create_new_client_event(builder)
1538 )
1539
1540 self.get_success(storage.persistence.persist_event(event, context))
1541
1542 # Now get rooms
1543 url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
1544 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1545
1546 self.assertEqual(200, channel.code, msg=channel.json_body)
1547 self.assertEqual(1, channel.json_body["total"])
1548 self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
12851549
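The joined_rooms tests above fix the response contract of GET /_synapse/admin/v1/users/<user_id>/joined_rooms: a JSON object with a "total" count and a "joined_rooms" list, with an empty list (rather than an error) for unknown or non-local users. A small illustration of the shape being asserted; the room ids are made up:

# Illustrative response body for the joined_rooms admin endpoint.
example_joined_rooms_response = {
    "total": 2,
    "joined_rooms": ["!room1:test", "!room2:test"],
}

assert example_joined_rooms_response["total"] == len(
    example_joined_rooms_response["joined_rooms"]
)
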
12861550
12871551 class PushersRestTestCase(unittest.HomeserverTestCase):
14001664 ]
14011665
14021666 def prepare(self, reactor, clock, hs):
1403 self.store = hs.get_datastore()
14041667 self.media_repo = hs.get_media_repository_resource()
14051668
14061669 self.admin_user = self.register_user("admin", "pass", admin=True)
18672130 ]
18682131
18692132 def prepare(self, reactor, clock, hs):
1870 self.store = hs.get_datastore()
1871
18722133 self.admin_user = self.register_user("admin", "pass", admin=True)
18732134 self.admin_user_tok = self.login("admin", "pass")
18742135
0 import json
0 # -*- coding: utf-8 -*-
1 # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
115 import time
216 import urllib.parse
17 from typing import Any, Dict, Union
18 from urllib.parse import urlencode
319
420 from mock import Mock
521
6 import jwt
22 import pymacaroons
23
24 from twisted.web.resource import Resource
725
826 import synapse.rest.admin
927 from synapse.appservice import ApplicationService
1028 from synapse.rest.client.v1 import login, logout
1129 from synapse.rest.client.v2_alpha import devices, register
1230 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
31 from synapse.rest.synapse.client.pick_idp import PickIdpResource
32 from synapse.rest.synapse.client.pick_username import pick_username_resource
33 from synapse.types import create_requester
1334
1435 from tests import unittest
15 from tests.unittest import override_config
36 from tests.handlers.test_oidc import HAS_OIDC
37 from tests.handlers.test_saml import has_saml2
38 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
39 from tests.test_utils.html_parsers import TestHtmlParser
40 from tests.unittest import HomeserverTestCase, override_config, skip_unless
41
42 try:
43 import jwt
44
45 HAS_JWT = True
46 except ImportError:
47 HAS_JWT = False
48
49
50 # public_base_url used in some tests
51 BASE_URL = "https://synapse/"
52
53 # CAS server used in some tests
54 CAS_SERVER = "https://fake.test"
55
56 # just enough to tell pysaml2 where to redirect to
57 SAML_SERVER = "https://test.saml.server/idp/sso"
58 TEST_SAML_METADATA = """
59 <md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
60 <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
61 <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
62 </md:IDPSSODescriptor>
63 </md:EntityDescriptor>
64 """ % {
65 "SAML_SERVER": SAML_SERVER,
66 }
1667
1768 LOGIN_URL = b"/_matrix/client/r0/login"
1869 TEST_URL = b"/_matrix/client/r0/account/whoami"
70
71 # a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
72 TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
73
74 # the query params in TEST_CLIENT_REDIRECT_URL
75 EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
1976
2077
2178 class LoginRestServletTestCase(unittest.HomeserverTestCase):
310367 self.assertEquals(channel.result["code"], b"200", channel.result)
311368
312369
313 class CASTestCase(unittest.HomeserverTestCase):
370 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
371 class MultiSSOTestCase(unittest.HomeserverTestCase):
372 """Tests for homeservers with multiple SSO providers enabled"""
314373
315374 servlets = [
316375 login.register_servlets,
317376 ]
318377
378 def default_config(self) -> Dict[str, Any]:
379 config = super().default_config()
380
381 config["public_baseurl"] = BASE_URL
382
383 config["cas_config"] = {
384 "enabled": True,
385 "server_url": CAS_SERVER,
386 "service_url": "https://matrix.goodserver.com:8448",
387 }
388
389 config["saml2_config"] = {
390 "sp_config": {
391 "metadata": {"inline": [TEST_SAML_METADATA]},
392 # use the XMLSecurity backend to avoid relying on xmlsec1
393 "crypto_backend": "XMLSecurity",
394 },
395 }
396
397 # default OIDC provider
398 config["oidc_config"] = TEST_OIDC_CONFIG
399
400 # additional OIDC providers
401 config["oidc_providers"] = [
402 {
403 "idp_id": "idp1",
404 "idp_name": "IDP1",
405 "discover": False,
406 "issuer": "https://issuer1",
407 "client_id": "test-client-id",
408 "client_secret": "test-client-secret",
409 "scopes": ["profile"],
410 "authorization_endpoint": "https://issuer1/auth",
411 "token_endpoint": "https://issuer1/token",
412 "userinfo_endpoint": "https://issuer1/userinfo",
413 "user_mapping_provider": {
414 "config": {"localpart_template": "{{ user.sub }}"}
415 },
416 }
417 ]
418 return config
419
420 def create_resource_dict(self) -> Dict[str, Resource]:
421 from synapse.rest.oidc import OIDCResource
422
423 d = super().create_resource_dict()
424 d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
425 d["/_synapse/oidc"] = OIDCResource(self.hs)
426 return d
427
428 def test_multi_sso_redirect(self):
429 """/login/sso/redirect should redirect to an identity picker"""
430 # first hit the redirect url, which should redirect to our idp picker
431 channel = self.make_request(
432 "GET",
433 "/_matrix/client/r0/login/sso/redirect?redirectUrl="
434 + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
435 )
436 self.assertEqual(channel.code, 302, channel.result)
437 uri = channel.headers.getRawHeaders("Location")[0]
438
439 # hitting that picker should give us some HTML
440 channel = self.make_request("GET", uri)
441 self.assertEqual(channel.code, 200, channel.result)
442
443 # parse the form to check it has fields assumed elsewhere in this class
444 p = TestHtmlParser()
445 p.feed(channel.result["body"].decode("utf-8"))
446 p.close()
447
448 self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
449
450 self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
451
452 def test_multi_sso_redirect_to_cas(self):
453 """If CAS is chosen, should redirect to the CAS server"""
454
455 channel = self.make_request(
456 "GET",
457 "/_synapse/client/pick_idp?redirectUrl="
458 + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
459 + "&idp=cas",
460 shorthand=False,
461 )
462 self.assertEqual(channel.code, 302, channel.result)
463 cas_uri = channel.headers.getRawHeaders("Location")[0]
464 cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
465
466 # it should redirect us to the login page of the cas server
467 self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
468
469 # check that the redirectUrl is correctly encoded in the service param - ie, the
470 # place that CAS will redirect to
471 cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
472 service_uri = cas_uri_params["service"][0]
473 _, service_uri_query = service_uri.split("?", 1)
474 service_uri_params = urllib.parse.parse_qs(service_uri_query)
475 self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
476
477 def test_multi_sso_redirect_to_saml(self):
478 """If SAML is chosen, should redirect to the SAML server"""
479 channel = self.make_request(
480 "GET",
481 "/_synapse/client/pick_idp?redirectUrl="
482 + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
483 + "&idp=saml",
484 )
485 self.assertEqual(channel.code, 302, channel.result)
486 saml_uri = channel.headers.getRawHeaders("Location")[0]
487 saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
488
489 # it should redirect us to the login page of the SAML server
490 self.assertEqual(saml_uri_path, SAML_SERVER)
491
492 # the RelayState is used to carry the client redirect url
493 saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
494 relay_state_param = saml_uri_params["RelayState"][0]
495 self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
496
497 def test_login_via_oidc(self):
498 """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
499
500 # pick the default OIDC provider
501 channel = self.make_request(
502 "GET",
503 "/_synapse/client/pick_idp?redirectUrl="
504 + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
505 + "&idp=oidc",
506 )
507 self.assertEqual(channel.code, 302, channel.result)
508 oidc_uri = channel.headers.getRawHeaders("Location")[0]
509 oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
510
511 # it should redirect us to the auth page of the OIDC server
512 self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
513
514 # ... and should have set a cookie including the redirect url
515 cookies = dict(
516 h.split(";")[0].split("=", maxsplit=1)
517 for h in channel.headers.getRawHeaders("Set-Cookie")
518 )
519
520 oidc_session_cookie = cookies["oidc_session"]
521 macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
522 self.assertEqual(
523 self._get_value_from_macaroon(macaroon, "client_redirect_url"),
524 TEST_CLIENT_REDIRECT_URL,
525 )
526
527 channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
528
529 # that should serve a confirmation page
530 self.assertEqual(channel.code, 200, channel.result)
531 self.assertTrue(
532 channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
533 )
534 p = TestHtmlParser()
535 p.feed(channel.text_body)
536 p.close()
537
538 # ... which should contain our redirect link
539 self.assertEqual(len(p.links), 1)
540 path, query = p.links[0].split("?", 1)
541 self.assertEqual(path, "https://x")
542
543 # it will have url-encoded the params properly, so we'll have to parse them
544 params = urllib.parse.parse_qsl(
545 query, keep_blank_values=True, strict_parsing=True, errors="strict"
546 )
547 self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
548 self.assertEqual(params[2][0], "loginToken")
549
550 # finally, submit the matrix login token to the login API, which gives us our
551 # matrix access token, mxid, and device id.
552 login_token = params[2][1]
553 chan = self.make_request(
554 "POST", "/login", content={"type": "m.login.token", "token": login_token},
555 )
556 self.assertEqual(chan.code, 200, chan.result)
557 self.assertEqual(chan.json_body["user_id"], "@user1:test")
558
559 def test_multi_sso_redirect_to_unknown(self):
560 """An unknown IdP should cause a 400"""
561 channel = self.make_request(
562 "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
563 )
564 self.assertEqual(channel.code, 400, channel.result)
565
566 @staticmethod
567 def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
568 prefix = key + " = "
569 for caveat in macaroon.caveats:
570 if caveat.caveat_id.startswith(prefix):
571 return caveat.caveat_id[len(prefix) :]
572 raise ValueError("No %s caveat in macaroon" % (key,))
573
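Note how the extra provider configured above with "idp_id": "idp1" is asserted to show up in the picker as "oidc-idp1": additional providers listed under oidc_providers are surfaced with an oidc- prefix on their idp_id, which is exactly what test_multi_sso_redirect checks. A condensed restatement of that relationship; the fragment below only echoes the config already set up in default_config():

# Additional OIDC provider from default_config() above.
oidc_providers_fragment = [
    {
        "idp_id": "idp1",  # surfaced to clients as "oidc-idp1"
        "idp_name": "IDP1",
        # endpoints and client credentials as configured in the test
    }
]

# The radio values the IdP picker is expected to offer.
expected_idp_radio_values = ["cas", "oidc", "oidc-idp1", "saml"]
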
574
575 class CASTestCase(unittest.HomeserverTestCase):
576
577 servlets = [
578 login.register_servlets,
579 ]
580
319581 def make_homeserver(self, reactor, clock):
320582 self.base_url = "https://matrix.goodserver.com/"
321583 self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
323585 config = self.default_config()
324586 config["cas_config"] = {
325587 "enabled": True,
326 "server_url": "https://fake.test",
588 "server_url": CAS_SERVER,
327589 "service_url": "https://matrix.goodserver.com:8448",
328590 }
329591
384646 channel = self.make_request("GET", cas_ticket_url)
385647
386648 # Test that the response is HTML.
387 self.assertEqual(channel.code, 200)
649 self.assertEqual(channel.code, 200, channel.result)
388650 content_type_header_value = ""
389651 for header in channel.result.get("headers", []):
390652 if header[0] == b"Content-Type":
409671 }
410672 )
411673 def test_cas_redirect_whitelisted(self):
412 """Tests that the SSO login flow serves a redirect to a whitelisted url
413 """
674 """Tests that the SSO login flow serves a redirect to a whitelisted url"""
414675 self._test_redirect("https://legit-site.com/")
415676
416677 @override_config({"public_baseurl": "https://example.com"})
441702
442703 # Deactivate the account.
443704 self.get_success(
444 self.deactivate_account_handler.deactivate_account(self.user_id, False)
705 self.deactivate_account_handler.deactivate_account(
706 self.user_id, False, create_requester(self.user_id)
707 )
445708 )
446709
447710 # Request the CAS ticket.
458721 self.assertIn(b"SSO account deactivated", channel.result["body"])
459722
460723
724 @skip_unless(HAS_JWT, "requires jwt")
461725 class JWTTestCase(unittest.HomeserverTestCase):
462726 servlets = [
463727 synapse.rest.admin.register_servlets_for_client_rest_resource,
474738 self.hs.config.jwt_algorithm = self.jwt_algorithm
475739 return self.hs
476740
477 def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
741 def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
478742 # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
479 result = jwt.encode(token, secret, self.jwt_algorithm)
743 result = jwt.encode(
744 payload, secret, self.jwt_algorithm
745 ) # type: Union[str, bytes]
480746 if isinstance(result, bytes):
481747 return result.decode("ascii")
482748 return result
483749
484750 def jwt_login(self, *args):
485 params = json.dumps(
486 {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
487 )
751 params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
488752 channel = self.make_request(b"POST", LOGIN_URL, params)
489753 return channel
490754
616880 )
617881
618882 def test_login_no_token(self):
619 params = json.dumps({"type": "org.matrix.login.jwt"})
883 params = {"type": "org.matrix.login.jwt"}
620884 channel = self.make_request(b"POST", LOGIN_URL, params)
621885 self.assertEqual(channel.result["code"], b"403", channel.result)
622886 self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
626890 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
627891 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
628892 # signed by the private key.
893 @skip_unless(HAS_JWT, "requires jwt")
629894 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
630895 servlets = [
631896 login.register_servlets,
683948 self.hs.config.jwt_algorithm = "RS256"
684949 return self.hs
685950
686 def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
951 def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
687952 # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
688 result = jwt.encode(token, secret, "RS256")
953 result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
689954 if isinstance(result, bytes):
690955 return result.decode("ascii")
691956 return result
692957
693958 def jwt_login(self, *args):
694 params = json.dumps(
695 {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
696 )
959 params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
697960 channel = self.make_request(b"POST", LOGIN_URL, params)
698961 return channel
699962
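Both JWT test cases above carry the same compatibility shim: PyJWT 2.0 returns str from jwt.encode where 1.x returned bytes. A standalone sketch of that normalisation, usable anywhere a token needs to end up as text; the function name, default secret and claims are placeholders:

from typing import Any, Dict, Union

import jwt  # PyJWT; the tests guard on HAS_JWT before relying on it


def encode_jwt(payload: Dict[str, Any], secret: str, algorithm: str = "HS256") -> str:
    """Encode a JWT and always return str, regardless of PyJWT version."""
    result = jwt.encode(payload, secret, algorithm)  # type: Union[str, bytes]
    if isinstance(result, bytes):
        # PyJWT < 2.0 returned bytes; decode for a uniform return type.
        return result.decode("ascii")
    return result

# e.g. encode_jwt({"sub": "frog"}, "shared-secret")
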
7631026 return self.hs
7641027
7651028 def test_login_appservice_user(self):
766 """Test that an appservice user can use /login
767 """
1029 """Test that an appservice user can use /login"""
7681030 self.register_as_user(AS_USER)
7691031
7701032 params = {
7781040 self.assertEquals(channel.result["code"], b"200", channel.result)
7791041
7801042 def test_login_appservice_user_bot(self):
781 """Test that the appservice bot can use /login
782 """
1043 """Test that the appservice bot can use /login"""
7831044 self.register_as_user(AS_USER)
7841045
7851046 params = {
7931054 self.assertEquals(channel.result["code"], b"200", channel.result)
7941055
7951056 def test_login_appservice_wrong_user(self):
796 """Test that non-as users cannot login with the as token
797 """
1057 """Test that non-as users cannot login with the as token"""
7981058 self.register_as_user(AS_USER)
7991059
8001060 params = {
8081068 self.assertEquals(channel.result["code"], b"403", channel.result)
8091069
8101070 def test_login_appservice_wrong_as(self):
811 """Test that as users cannot login with wrong as token
812 """
1071 """Test that as users cannot login with wrong as token"""
8131072 self.register_as_user(AS_USER)
8141073
8151074 params = {
8241083
8251084 def test_login_appservice_no_token(self):
8261085 """Test that users must provide a token when using the appservice
827 login method
1086 login method
8281087 """
8291088 self.register_as_user(AS_USER)
8301089
8351094 channel = self.make_request(b"POST", LOGIN_URL, params)
8361095
8371096 self.assertEquals(channel.result["code"], b"401", channel.result)
1097
1098
1099 @skip_unless(HAS_OIDC, "requires OIDC")
1100 class UsernamePickerTestCase(HomeserverTestCase):
1101 """Tests for the username picker flow of SSO login"""
1102
1103 servlets = [login.register_servlets]
1104
1105 def default_config(self):
1106 config = super().default_config()
1107 config["public_baseurl"] = BASE_URL
1108
1109 config["oidc_config"] = {}
1110 config["oidc_config"].update(TEST_OIDC_CONFIG)
1111 config["oidc_config"]["user_mapping_provider"] = {
1112 "config": {"display_name_template": "{{ user.displayname }}"}
1113 }
1114
1115 # whitelist this client URI so we redirect straight to it rather than
1116 # serving a confirmation page
1117 config["sso"] = {"client_whitelist": ["https://x"]}
1118 return config
1119
1120 def create_resource_dict(self) -> Dict[str, Resource]:
1121 from synapse.rest.oidc import OIDCResource
1122
1123 d = super().create_resource_dict()
1124 d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
1125 d["/_synapse/oidc"] = OIDCResource(self.hs)
1126 return d
1127
1128 def test_username_picker(self):
1129 """Test the happy path of a username picker flow."""
1130
1131 # do the start of the login flow
1132 channel = self.helper.auth_via_oidc(
1133 {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
1134 )
1135
1136 # that should redirect to the username picker
1137 self.assertEqual(channel.code, 302, channel.result)
1138 picker_url = channel.headers.getRawHeaders("Location")[0]
1139 self.assertEqual(picker_url, "/_synapse/client/pick_username")
1140
1141 # ... with a username_mapping_session cookie
1142 cookies = {} # type: Dict[str,str]
1143 channel.extract_cookies(cookies)
1144 self.assertIn("username_mapping_session", cookies)
1145 session_id = cookies["username_mapping_session"]
1146
1147 # introspect the sso handler a bit to check that the username mapping session
1148 # looks ok.
1149 username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
1150 self.assertIn(
1151 session_id, username_mapping_sessions, "session id not found in map",
1152 )
1153 session = username_mapping_sessions[session_id]
1154 self.assertEqual(session.remote_user_id, "tester")
1155 self.assertEqual(session.display_name, "Jonny")
1156 self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
1157
1158 # the expiry time should be about 15 minutes away
1159 expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
1160 self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
1161
1162 # Now, submit a username to the username picker, which should serve a redirect
1163 # back to the client
1164 submit_path = picker_url + "/submit"
1165 content = urlencode({b"username": b"bobby"}).encode("utf8")
1166 chan = self.make_request(
1167 "POST",
1168 path=submit_path,
1169 content=content,
1170 content_is_form=True,
1171 custom_headers=[
1172 ("Cookie", "username_mapping_session=" + session_id),
1173 # old versions of twisted don't do form-parsing without a valid
1174 # content-length header.
1175 ("Content-Length", str(len(content))),
1176 ],
1177 )
1178 self.assertEqual(chan.code, 302, chan.result)
1179 location_headers = chan.headers.getRawHeaders("Location")
1180 # ensure that the returned location matches the requested redirect URL
1181 path, query = location_headers[0].split("?", 1)
1182 self.assertEqual(path, "https://x")
1183
1184 # it will have url-encoded the params properly, so we'll have to parse them
1185 params = urllib.parse.parse_qsl(
1186 query, keep_blank_values=True, strict_parsing=True, errors="strict"
1187 )
1188 self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
1189 self.assertEqual(params[2][0], "loginToken")
1190
1191 # fish the login token out of the returned redirect uri
1192 login_token = params[2][1]
1193
1194 # finally, submit the matrix login token to the login API, which gives us our
1195 # matrix access token, mxid, and device id.
1196 chan = self.make_request(
1197 "POST", "/login", content={"type": "m.login.token", "token": login_token},
1198 )
1199 self.assertEqual(chan.code, 200, chan.result)
1200 self.assertEqual(chan.json_body["user_id"], "@bobby:test")
2828 from synapse.rest import admin
2929 from synapse.rest.client.v1 import directory, login, profile, room
3030 from synapse.rest.client.v2_alpha import account
31 from synapse.types import JsonDict, RoomAlias, UserID
31 from synapse.types import JsonDict, RoomAlias, UserID, create_requester
3232 from synapse.util.stringutils import random_string
3333
3434 from tests import unittest
16861686
16871687 deactivate_account_handler = self.hs.get_deactivate_account_handler()
16881688 self.get_success(
1689 deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
1689 deactivate_account_handler.deactivate_account(
1690 self.user_id, True, create_requester(self.user_id)
1691 )
16901692 )
16911693
16921694 # Invite another user in the room. This is needed because messages will be
11 # Copyright 2014-2016 OpenMarket Ltd
22 # Copyright 2017 Vector Creations Ltd
33 # Copyright 2018-2019 New Vector Ltd
4 # Copyright 2019-2020 The Matrix.org Foundation C.I.C.
4 # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
55 #
66 # Licensed under the Apache License, Version 2.0 (the "License");
77 # you may not use this file except in compliance with the License.
1919 import re
2020 import time
2121 import urllib.parse
22 from typing import Any, Dict, Optional
22 from typing import Any, Dict, Mapping, MutableMapping, Optional
2323
2424 from mock import patch
2525
3131 from synapse.api.constants import Membership
3232 from synapse.types import JsonDict
3333
34 from tests.server import FakeSite, make_request
34 from tests.server import FakeChannel, FakeSite, make_request
3535 from tests.test_utils import FakeResponse
36 from tests.test_utils.html_parsers import TestHtmlParser
3637
3738
3839 @attr.s
361362 the normal places.
362363 """
363364 client_redirect_url = "https://x"
364
365 # first hit the redirect url (which will issue a cookie and state)
365 channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
366
367 # expect a confirmation page
368 assert channel.code == 200, channel.result
369
370 # fish the matrix login token out of the body of the confirmation page
371 m = re.search(
372 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
373 channel.text_body,
374 )
375 assert m, channel.text_body
376 login_token = m.group(1)
377
378 # finally, submit the matrix login token to the login API, which gives us our
379 # matrix access token and device id.
366380 channel = make_request(
367381 self.hs.get_reactor(),
368382 self.site,
369 "GET",
370 "/login/sso/redirect?redirectUrl=" + client_redirect_url,
371 )
372 # that will redirect to the OIDC IdP, but we skip that and go straight
383 "POST",
384 "/login",
385 content={"type": "m.login.token", "token": login_token},
386 )
387 assert channel.code == 200
388 return channel.json_body
389
390 def auth_via_oidc(
391 self,
392 user_info_dict: JsonDict,
393 client_redirect_url: Optional[str] = None,
394 ui_auth_session_id: Optional[str] = None,
395 ) -> FakeChannel:
396 """Perform an OIDC authentication flow via a mock OIDC provider.
397
398 This can be used for either login or user-interactive auth.
399
400 Starts by making a request to the relevant synapse redirect endpoint, which is
401 expected to serve a 302 to the OIDC provider. We then make a request to the
402 OIDC callback endpoint, intercepting the HTTP requests that will get sent back
403 to the OIDC provider.
404
405 Requires that "oidc_config" in the homeserver config be set appropriately
406 (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
407 "public_base_url".
408
409 Also requires the login servlet and the OIDC callback resource to be mounted at
410 the normal places.
411
412 Args:
413 user_info_dict: the remote userinfo that the OIDC provider should present.
414 Typically this should be '{"sub": "<remote user id>"}'.
415 client_redirect_url: for a login flow, the client redirect URL to pass to
416 the login redirect endpoint
417 ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
418 of the UI auth.
419
420 Returns:
421 A FakeChannel containing the result of calling the OIDC callback endpoint.
422 Note that the response code may be a 200, 302 or 400 depending on how things
423 went.
424 """
425
426 cookies = {}
427
428 # if we're doing a ui auth, hit the ui auth redirect endpoint
429 if ui_auth_session_id:
430 # can't set the client redirect url for UI Auth
431 assert client_redirect_url is None
432 oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
433 else:
434 # otherwise, hit the login redirect endpoint
435 oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
436
437 # we now have a URI for the OIDC IdP, but we skip that and go straight
373438 # back to synapse's OIDC callback resource. However, we do need the "state"
374 # param that synapse passes to the IdP via query params, and the cookie that
375 # synapse passes to the client.
376 assert channel.code == 302
377 oauth_uri = channel.headers.getRawHeaders("Location")[0]
378 params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
379 redirect_uri = "%s?%s" % (
439 # param that synapse passes to the IdP via query params, as well as the cookie
440 # that synapse passes to the client.
441
442 oauth_uri_path, _ = oauth_uri.split("?", 1)
443 assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
444 "unexpected SSO URI " + oauth_uri_path
445 )
446 return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
447
448 def complete_oidc_auth(
449 self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
450 ) -> FakeChannel:
451 """Mock out an OIDC authentication flow
452
453 Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
454 initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
455 Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
456 sent back to the OIDC provider.
457
458 Requires the OIDC callback resource to be mounted at the normal place.
459
460 Args:
461 oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
462 from initiate_sso_login or initiate_sso_ui_auth).
463 cookies: the cookies set by synapse's redirect endpoint, which will be
464 sent back to the callback endpoint.
465 user_info_dict: the remote userinfo that the OIDC provider should present.
466 Typically this should be '{"sub": "<remote user id>"}'.
467
468 Returns:
469 A FakeChannel containing the result of calling the OIDC callback endpoint.
470 """
471 _, oauth_uri_qs = oauth_uri.split("?", 1)
472 params = urllib.parse.parse_qs(oauth_uri_qs)
473 callback_uri = "%s?%s" % (
380474 urllib.parse.urlparse(params["redirect_uri"][0]).path,
381475 urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
382476 )
383 cookies = {}
384 for h in channel.headers.getRawHeaders("Set-Cookie"):
385 parts = h.split(";")
386 k, v = parts[0].split("=", maxsplit=1)
387 cookies[k] = v
388477
389478 # before we hit the callback uri, stub out some methods in the http client so
390479 # that we don't have to handle full HTTPS requests.
391
392480 # (expected url, json response) pairs, in the order we expect them.
393481 expected_requests = [
394482 # first we get a hit to the token endpoint, which we tell to return
395483 # a dummy OIDC access token
396 ("https://issuer.test/token", {"access_token": "TEST"}),
484 (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
397485 # and then one to the user_info endpoint, which returns our remote user id.
398 ("https://issuer.test/userinfo", {"sub": remote_user_id}),
486 (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
399487 ]
400488
401489 async def mock_req(method: str, uri: str, data=None, headers=None):
412500 self.hs.get_reactor(),
413501 self.site,
414502 "GET",
415 redirect_uri,
503 callback_uri,
416504 custom_headers=[
417505 ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
418506 ],
419507 )
420
421 # expect a confirmation page
422 assert channel.code == 200
423
424 # fish the matrix login token out of the body of the confirmation page
425 m = re.search(
426 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
427 channel.result["body"].decode("utf-8"),
428 )
429 assert m
430 login_token = m.group(1)
431
432 # finally, submit the matrix login token to the login API, which gives us our
433 # matrix access token and device id.
508 return channel
509
510 def initiate_sso_login(
511 self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
512 ) -> str:
513 """Make a request to the login-via-sso redirect endpoint, and return the target
514
515 Assumes that exactly one SSO provider has been configured. Requires the login
516 servlet to be mounted.
517
518 Args:
519 client_redirect_url: the client redirect URL to pass to the login redirect
520 endpoint
521 cookies: any cookies returned will be added to this dict
522
523 Returns:
524 the URI that the client gets redirected to (ie, the SSO server)
525 """
526 params = {}
527 if client_redirect_url:
528 params["redirectUrl"] = client_redirect_url
529
530 # hit the redirect url (which will issue a cookie and state)
434531 channel = make_request(
435532 self.hs.get_reactor(),
436533 self.site,
437 "POST",
438 "/login",
439 content={"type": "m.login.token", "token": login_token},
440 )
441 assert channel.code == 200
442 return channel.json_body
534 "GET",
535 "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
536 )
537
538 assert channel.code == 302
539 channel.extract_cookies(cookies)
540 return channel.headers.getRawHeaders("Location")[0]
541
542 def initiate_sso_ui_auth(
543 self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
544 ) -> str:
545 """Make a request to the ui-auth-via-sso endpoint, and return the target
546
547 Assumes that exactly one SSO provider has been configured. Requires the
548 AuthRestServlet to be mounted.
549
550 Args:
551 ui_auth_session_id: the session id of the UI auth
552 cookies: any cookies returned will be added to this dict
553
554 Returns:
555 the URI that the client gets linked to (ie, the SSO server)
556 """
557 sso_redirect_endpoint = (
558 "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
559 + urllib.parse.urlencode({"session": ui_auth_session_id})
560 )
561 # hit the redirect url (which will issue a cookie and state)
562 channel = make_request(
563 self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
564 )
565 # that should serve a confirmation page
566 assert channel.code == 200, channel.text_body
567 channel.extract_cookies(cookies)
568
569 # parse the confirmation page to fish out the link.
570 p = TestHtmlParser()
571 p.feed(channel.text_body)
572 p.close()
573 assert len(p.links) == 1, "not exactly one link in confirmation page"
574 oauth_uri = p.links[0]
575 return oauth_uri
443576
444577
445578 # an 'oidc_config' suitable for login_via_oidc.
579 TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
580 TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
581 TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
446582 TEST_OIDC_CONFIG = {
447583 "enabled": True,
448584 "discover": False,
450586 "client_id": "test-client-id",
451587 "client_secret": "test-client-secret",
452588 "scopes": ["profile"],
453 "authorization_endpoint": "https://z",
454 "token_endpoint": "https://issuer.test/token",
455 "userinfo_endpoint": "https://issuer.test/userinfo",
589 "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
590 "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
591 "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
456592 "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
457593 }
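
The helpers above (initiate_sso_login, complete_oidc_auth and the auth_via_oidc wrapper) split the mocked OIDC flow into reusable pieces so both login and UI-auth tests can drive it. A hedged sketch of how a test built on this RestHelper might use them, following the docstrings above; the test method, user id and redirect URL are illustrative:

# Inside a HomeserverTestCase whose config includes TEST_OIDC_CONFIG and which
# mounts the login servlet and the OIDC callback resource:
def test_oidc_login_sketch(self):
    # Run the whole mocked flow: redirect endpoint -> fake IdP -> callback.
    channel = self.helper.auth_via_oidc({"sub": "alice"}, "https://x")

    # For a plain login the callback serves the confirmation page; the login
    # token can then be fished out and posted to /login, as login_via_oidc
    # does above.
    assert channel.code == 200, channel.result
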
00 # -*- coding: utf-8 -*-
11 # Copyright 2018 New Vector
2 # Copyright 2020-2021 The Matrix.org Foundation C.I.C
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1112 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
14
1515 from typing import Union
1616
1717 from twisted.internet.defer import succeed
2525 from synapse.types import JsonDict, UserID
2626
2727 from tests import unittest
28 from tests.handlers.test_oidc import HAS_OIDC
2829 from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
2930 from tests.server import FakeChannel
31 from tests.unittest import override_config, skip_unless
3032
3133
3234 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
157159
158160 def default_config(self):
159161 config = super().default_config()
160
161 # we enable OIDC as a way of testing SSO flows
162 oidc_config = {}
163 oidc_config.update(TEST_OIDC_CONFIG)
164 oidc_config["allow_existing_users"] = True
165
166 config["oidc_config"] = oidc_config
167162 config["public_baseurl"] = "https://synapse.test"
163
164 if HAS_OIDC:
165 # we enable OIDC as a way of testing SSO flows
166 oidc_config = {}
167 oidc_config.update(TEST_OIDC_CONFIG)
168 oidc_config["allow_existing_users"] = True
169 config["oidc_config"] = oidc_config
170
168171 return config
169172
170173 def create_resource_dict(self):
171174 resource_dict = super().create_resource_dict()
172 # mount the OIDC resource at /_synapse/oidc
173 resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
175 if HAS_OIDC:
176 # mount the OIDC resource at /_synapse/oidc
177 resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
174178 return resource_dict
175179
176180 def prepare(self, reactor, clock, hs):
379383 # Note that *no auth* information is provided, not even a session iD!
380384 self.delete_device(self.user_tok, self.device_id, 200)
381385
386 @skip_unless(HAS_OIDC, "requires OIDC")
387 @override_config({"oidc_config": TEST_OIDC_CONFIG})
388 def test_ui_auth_via_sso(self):
389 """Test a successful UI Auth flow via SSO
390
391 This includes:
392 * hitting the UIA SSO redirect endpoint
393 * checking it serves a confirmation page which links to the OIDC provider
394 * calling back to the synapse oidc callback
395 * checking that the original operation succeeds
396 """
397
398 # log the user in
399 remote_user_id = UserID.from_string(self.user).localpart
400 login_resp = self.helper.login_via_oidc(remote_user_id)
401 self.assertEqual(login_resp["user_id"], self.user)
402
403 # initiate a UI Auth process by attempting to delete the device
404 channel = self.delete_device(self.user_tok, self.device_id, 401)
405
406 # check that SSO is offered
407 flows = channel.json_body["flows"]
408 self.assertIn({"stages": ["m.login.sso"]}, flows)
409
410 # run the UIA-via-SSO flow
411 session_id = channel.json_body["session"]
412 channel = self.helper.auth_via_oidc(
413 {"sub": remote_user_id}, ui_auth_session_id=session_id
414 )
415
416 # that should serve a confirmation page
417 self.assertEqual(channel.code, 200, channel.result)
418
419 # and now the delete request should succeed.
420 self.delete_device(
421 self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
422 )
423
424 @skip_unless(HAS_OIDC, "requires OIDC")
425 @override_config({"oidc_config": TEST_OIDC_CONFIG})
382426 def test_does_not_offer_password_for_sso_user(self):
383427 login_resp = self.helper.login_via_oidc("username")
384428 user_tok = login_resp["access_token"]
392436 self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
393437
394438 def test_does_not_offer_sso_for_password_user(self):
395 # now call the device deletion API: we should get the option to auth with SSO
396 # and not password.
397439 channel = self.delete_device(self.user_tok, self.device_id, 401)
398440
399441 flows = channel.json_body["flows"]
400442 self.assertEqual(flows, [{"stages": ["m.login.password"]}])
401443
444 @skip_unless(HAS_OIDC, "requires OIDC")
445 @override_config({"oidc_config": TEST_OIDC_CONFIG})
402446 def test_offers_both_flows_for_upgraded_user(self):
403447 """A user that had a password and then logged in with SSO should get both flows
404448 """
412456 self.assertIn({"stages": ["m.login.password"]}, flows)
413457 self.assertIn({"stages": ["m.login.sso"]}, flows)
414458 self.assertEqual(len(flows), 2)
459
460 @skip_unless(HAS_OIDC, "requires OIDC")
461 @override_config({"oidc_config": TEST_OIDC_CONFIG})
462 def test_ui_auth_fails_for_incorrect_sso_user(self):
463 """If the user tries to authenticate with the wrong SSO user, they get an error
464 """
465 # log the user in
466 login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
467 self.assertEqual(login_resp["user_id"], self.user)
468
469 # start a UI Auth flow by attempting to delete a device
470 channel = self.delete_device(self.user_tok, self.device_id, 401)
471
472 flows = channel.json_body["flows"]
473 self.assertIn({"stages": ["m.login.sso"]}, flows)
474 session_id = channel.json_body["session"]
475
476 # do the OIDC auth, but auth as the wrong user
477 channel = self.helper.auth_via_oidc(
478 {"sub": "wrong_user"}, ui_auth_session_id=session_id
479 )
480
481 # that should return a failure message
482 self.assertSubstring("We were unable to validate", channel.text_body)
483
484 # ... and the delete op should now fail with a 403
485 self.delete_device(
486 self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
487 )
2525 from tests import unittest
2626 from tests.server import FakeTransport
2727
28 try:
29 import lxml
30 except ImportError:
31 lxml = None
32
2833
2934 class URLPreviewTests(unittest.HomeserverTestCase):
35 if not lxml:
36 skip = "url preview feature requires lxml"
3037
3138 hijack_auth = True
3239 user_id = "@test:user"
3939 "m.identity_server": {"base_url": "https://testis"},
4040 },
4141 )
42
43 def test_well_known_no_public_baseurl(self):
44 self.hs.config.public_baseurl = None
45
46 channel = self.make_request(
47 "GET", "/.well-known/matrix/client", shorthand=False
48 )
49
50 self.assertEqual(channel.code, 404)
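
The added test covers the case where public_baseurl is unset, in which case /.well-known/matrix/client now returns a 404 rather than a document. For contrast, a hedged sketch of the document shape the earlier assertions in this file check; the m.homeserver base_url value is a placeholder, and only the m.identity_server entry is taken verbatim from the test config above:

# Illustrative /.well-known/matrix/client body when public_baseurl is set.
example_well_known = {
    "m.homeserver": {"base_url": "https://public.baseurl.example/"},
    "m.identity_server": {"base_url": "https://testis"},
}
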
11 import logging
22 from collections import deque
33 from io import SEEK_END, BytesIO
4 from typing import Callable, Iterable, Optional, Tuple, Union
4 from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
55
66 import attr
77 from typing_extensions import Deque
5050
5151 @property
5252 def json_body(self):
53 if not self.result:
54 raise Exception("No result yet.")
55 return json.loads(self.result["body"].decode("utf8"))
53 return json.loads(self.text_body)
54
55 @property
56 def text_body(self) -> str:
57 """The body of the result, utf-8-decoded.
58
59 Raises an exception if the request has not yet completed.
60 """
61 if not self.is_finished():
62 raise Exception("Request not yet completed")
63 return self.result["body"].decode("utf8")
64
65 def is_finished(self) -> bool:
66 """check if the response has been completely received"""
67 return self.result.get("done", False)
5668
5769 @property
5870 def code(self):
6173 return int(self.result["code"])
6274
6375 @property
64 def headers(self):
76 def headers(self) -> Headers:
6577 if not self.result:
6678 raise Exception("No result yet.")
6779 h = Headers()
123135 self._reactor.run()
124136 x = 0
125137
126 while not self.result.get("done"):
138 while not self.is_finished():
127139 # If there's a producer, tell it to resume producing so we get content
128140 if self._producer:
129141 self._producer.resumeProducing()
134146 raise TimedOutException("Timed out waiting for request to finish.")
135147
136148 self._reactor.advance(0.1)
149
150 def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
151 """Process the contents of any Set-Cookie headers in the response
152
153 Any cookies found are added to the given dict.
154 """
155 for h in self.headers.getRawHeaders("Set-Cookie"):
156 parts = h.split(";")
157 k, v = parts[0].split("=", maxsplit=1)
158 cookies[k] = v
137159
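extract_cookies centralises the Set-Cookie parsing that the SSO tests previously open-coded. A short usage sketch, assuming a FakeChannel returned by make_request; the helper function name is illustrative:

from typing import Dict


def cookies_from(channel) -> Dict[str, str]:
    """Collect the Set-Cookie values from a FakeChannel into a dict."""
    cookies = {}  # type: Dict[str, str]
    channel.extract_cookies(cookies)
    return cookies

# The OIDC helpers then replay them on the callback request, e.g.:
# custom_headers = [("Cookie", "%s=%s" % (k, v)) for k, v in cookies.items()]
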
138160
139161 class FakeSite:
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from typing import Iterable, Set
16
17 from synapse.api.constants import AccountDataTypes
18
19 from tests import unittest
20
21
22 class IgnoredUsersTestCase(unittest.HomeserverTestCase):
23 def prepare(self, hs, reactor, clock):
24 self.store = self.hs.get_datastore()
25 self.user = "@user:test"
26
27 def _update_ignore_list(
28 self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
29 ) -> None:
30 """Update the account data to block the given users."""
31 if ignorer_user_id is None:
32 ignorer_user_id = self.user
33
34 self.get_success(
35 self.store.add_account_data_for_user(
36 ignorer_user_id,
37 AccountDataTypes.IGNORED_USER_LIST,
38 {"ignored_users": {u: {} for u in ignored_user_ids}},
39 )
40 )
41
42 def assert_ignorers(
43 self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
44 ) -> None:
45 self.assertEqual(
46 self.get_success(self.store.ignored_by(ignored_user_id)),
47 expected_ignorer_user_ids,
48 )
49
50 def test_ignoring_users(self):
51 """Basic adding/removing of users from the ignore list."""
52 self._update_ignore_list("@other:test", "@another:remote")
53
54 # Check a user which no one ignores.
55 self.assert_ignorers("@user:test", set())
56
57 # Check a local user which is ignored.
58 self.assert_ignorers("@other:test", {self.user})
59
60 # Check a remote user which is ignored.
61 self.assert_ignorers("@another:remote", {self.user})
62
63 # Add one user, remove one user, and leave one user.
64 self._update_ignore_list("@foo:test", "@another:remote")
65
66 # Check the removed user.
67 self.assert_ignorers("@other:test", set())
68
69 # Check the added user.
70 self.assert_ignorers("@foo:test", {self.user})
71
72 # Check the user that was left (still ignored).
73 self.assert_ignorers("@another:remote", {self.user})
74
75 def test_caching(self):
76 """Ensure that caching works properly between different users."""
77 # The first user ignores a user.
78 self._update_ignore_list("@other:test")
79 self.assert_ignorers("@other:test", {self.user})
80
81 # The second user ignores them.
82 self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
83 self.assert_ignorers("@other:test", {self.user, "@second:test"})
84
85 # The first user un-ignores them.
86 self._update_ignore_list()
87 self.assert_ignorers("@other:test", {"@second:test"})
88
89 def test_invalid_data(self):
90 """Invalid data ends up clearing out the ignored users list."""
91 # Add some data and ensure it is there.
92 self._update_ignore_list("@other:test")
93 self.assert_ignorers("@other:test", {self.user})
94
95 # No ignored_users key.
96 self.get_success(
97 self.store.add_account_data_for_user(
98 self.user, AccountDataTypes.IGNORED_USER_LIST, {},
99 )
100 )
101
102 # No one ignores the user now.
103 self.assert_ignorers("@other:test", set())
104
105 # Add some data and ensure it is there.
106 self._update_ignore_list("@other:test")
107 self.assert_ignorers("@other:test", {self.user})
108
109 # Invalid data.
110 self.get_success(
111 self.store.add_account_data_for_user(
112 self.user,
113 AccountDataTypes.IGNORED_USER_LIST,
114 {"ignored_users": "unexpected"},
115 )
116 )
117
118 # No one ignores the user now.
119 self.assert_ignorers("@other:test", set())
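For reference, the account-data content shape these tests build via `_update_ignore_list` (the Matrix `m.ignored_user_list` format) looks roughly like this; the variable name below is illustrative only:

ignored_user_list_content = {
    "ignored_users": {
        "@other:test": {},
        "@another:remote": {},
    }
}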
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the 'License');
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an 'AS IS' BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from typing import Dict, List, Set, Tuple
16
17 from twisted.trial import unittest
18
19 from synapse.api.constants import EventTypes
20 from synapse.api.room_versions import RoomVersions
21 from synapse.events import EventBase
22 from synapse.rest import admin
23 from synapse.rest.client.v1 import login, room
24 from synapse.storage.databases.main.events import _LinkMap
25 from synapse.types import create_requester
26
27 from tests.unittest import HomeserverTestCase
28
29
30 class EventChainStoreTestCase(HomeserverTestCase):
31 def prepare(self, reactor, clock, hs):
32 self.store = hs.get_datastore()
33 self._next_stream_ordering = 1
34
35 def test_simple(self):
36 """Test that the example in `docs/auth_chain_difference_algorithm.md`
37 works.
38 """
39
40 event_factory = self.hs.get_event_builder_factory()
41 bob = "@creator:test"
42 alice = "@alice:test"
43 room_id = "!room:test"
44
45 # Ensure that we have a rooms entry so that we generate the chain index.
46 self.get_success(
47 self.store.store_room(
48 room_id=room_id,
49 room_creator_user_id="",
50 is_public=True,
51 room_version=RoomVersions.V6,
52 )
53 )
54
55 create = self.get_success(
56 event_factory.for_room_version(
57 RoomVersions.V6,
58 {
59 "type": EventTypes.Create,
60 "state_key": "",
61 "sender": bob,
62 "room_id": room_id,
63 "content": {"tag": "create"},
64 },
65 ).build(prev_event_ids=[], auth_event_ids=[])
66 )
67
68 bob_join = self.get_success(
69 event_factory.for_room_version(
70 RoomVersions.V6,
71 {
72 "type": EventTypes.Member,
73 "state_key": bob,
74 "sender": bob,
75 "room_id": room_id,
76 "content": {"tag": "bob_join"},
77 },
78 ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
79 )
80
81 power = self.get_success(
82 event_factory.for_room_version(
83 RoomVersions.V6,
84 {
85 "type": EventTypes.PowerLevels,
86 "state_key": "",
87 "sender": bob,
88 "room_id": room_id,
89 "content": {"tag": "power"},
90 },
91 ).build(
92 prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
93 )
94 )
95
96 alice_invite = self.get_success(
97 event_factory.for_room_version(
98 RoomVersions.V6,
99 {
100 "type": EventTypes.Member,
101 "state_key": alice,
102 "sender": bob,
103 "room_id": room_id,
104 "content": {"tag": "alice_invite"},
105 },
106 ).build(
107 prev_event_ids=[],
108 auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
109 )
110 )
111
112 alice_join = self.get_success(
113 event_factory.for_room_version(
114 RoomVersions.V6,
115 {
116 "type": EventTypes.Member,
117 "state_key": alice,
118 "sender": alice,
119 "room_id": room_id,
120 "content": {"tag": "alice_join"},
121 },
122 ).build(
123 prev_event_ids=[],
124 auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
125 )
126 )
127
128 power_2 = self.get_success(
129 event_factory.for_room_version(
130 RoomVersions.V6,
131 {
132 "type": EventTypes.PowerLevels,
133 "state_key": "",
134 "sender": bob,
135 "room_id": room_id,
136 "content": {"tag": "power_2"},
137 },
138 ).build(
139 prev_event_ids=[],
140 auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
141 )
142 )
143
144 bob_join_2 = self.get_success(
145 event_factory.for_room_version(
146 RoomVersions.V6,
147 {
148 "type": EventTypes.Member,
149 "state_key": bob,
150 "sender": bob,
151 "room_id": room_id,
152 "content": {"tag": "bob_join_2"},
153 },
154 ).build(
155 prev_event_ids=[],
156 auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
157 )
158 )
159
160 alice_join2 = self.get_success(
161 event_factory.for_room_version(
162 RoomVersions.V6,
163 {
164 "type": EventTypes.Member,
165 "state_key": alice,
166 "sender": alice,
167 "room_id": room_id,
168 "content": {"tag": "alice_join2"},
169 },
170 ).build(
171 prev_event_ids=[],
172 auth_event_ids=[
173 create.event_id,
174 alice_join.event_id,
175 power_2.event_id,
176 ],
177 )
178 )
179
180 events = [
181 create,
182 bob_join,
183 power,
184 alice_invite,
185 alice_join,
186 bob_join_2,
187 power_2,
188 alice_join2,
189 ]
190
191 expected_links = [
192 (bob_join, create),
193 (power, create),
194 (power, bob_join),
195 (alice_invite, create),
196 (alice_invite, power),
197 (alice_invite, bob_join),
198 (bob_join_2, power),
199 (alice_join2, power_2),
200 ]
201
202 self.persist(events)
203 chain_map, link_map = self.fetch_chains(events)
204
205 # Check that the expected links and only the expected links have been
206 # added.
207 self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
208
209 for start, end in expected_links:
210 start_id, start_seq = chain_map[start.event_id]
211 end_id, end_seq = chain_map[end.event_id]
212
213 self.assertIn(
214 (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
215 )
216
217 # Test that everything can reach the create event, but the create event
218 # can't reach anything.
219 for event in events[1:]:
220 self.assertTrue(
221 link_map.exists_path_from(
222 chain_map[event.event_id], chain_map[create.event_id]
223 ),
224 )
225
226 self.assertFalse(
227 link_map.exists_path_from(
228 chain_map[create.event_id], chain_map[event.event_id],
229 ),
230 )
231
232 def test_out_of_order_events(self):
233 """Test that we handle persisting events that we don't have the full
234 auth chain for yet (which should only happen for out of band memberships).
235 """
236 event_factory = self.hs.get_event_builder_factory()
237 bob = "@creator:test"
238 alice = "@alice:test"
239 room_id = "!room:test"
240
241 # Ensure that we have a rooms entry so that we generate the chain index.
242 self.get_success(
243 self.store.store_room(
244 room_id=room_id,
245 room_creator_user_id="",
246 is_public=True,
247 room_version=RoomVersions.V6,
248 )
249 )
250
251 # First persist the base room.
252 create = self.get_success(
253 event_factory.for_room_version(
254 RoomVersions.V6,
255 {
256 "type": EventTypes.Create,
257 "state_key": "",
258 "sender": bob,
259 "room_id": room_id,
260 "content": {"tag": "create"},
261 },
262 ).build(prev_event_ids=[], auth_event_ids=[])
263 )
264
265 bob_join = self.get_success(
266 event_factory.for_room_version(
267 RoomVersions.V6,
268 {
269 "type": EventTypes.Member,
270 "state_key": bob,
271 "sender": bob,
272 "room_id": room_id,
273 "content": {"tag": "bob_join"},
274 },
275 ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
276 )
277
278 power = self.get_success(
279 event_factory.for_room_version(
280 RoomVersions.V6,
281 {
282 "type": EventTypes.PowerLevels,
283 "state_key": "",
284 "sender": bob,
285 "room_id": room_id,
286 "content": {"tag": "power"},
287 },
288 ).build(
289 prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
290 )
291 )
292
293 self.persist([create, bob_join, power])
294
295 # Now persist an invite and a couple of memberships out of order.
296 alice_invite = self.get_success(
297 event_factory.for_room_version(
298 RoomVersions.V6,
299 {
300 "type": EventTypes.Member,
301 "state_key": alice,
302 "sender": bob,
303 "room_id": room_id,
304 "content": {"tag": "alice_invite"},
305 },
306 ).build(
307 prev_event_ids=[],
308 auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
309 )
310 )
311
312 alice_join = self.get_success(
313 event_factory.for_room_version(
314 RoomVersions.V6,
315 {
316 "type": EventTypes.Member,
317 "state_key": alice,
318 "sender": alice,
319 "room_id": room_id,
320 "content": {"tag": "alice_join"},
321 },
322 ).build(
323 prev_event_ids=[],
324 auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
325 )
326 )
327
328 alice_join2 = self.get_success(
329 event_factory.for_room_version(
330 RoomVersions.V6,
331 {
332 "type": EventTypes.Member,
333 "state_key": alice,
334 "sender": alice,
335 "room_id": room_id,
336 "content": {"tag": "alice_join2"},
337 },
338 ).build(
339 prev_event_ids=[],
340 auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
341 )
342 )
343
344 self.persist([alice_join])
345 self.persist([alice_join2])
346 self.persist([alice_invite])
347
348 # The end result should be sane.
349 events = [create, bob_join, power, alice_invite, alice_join]
350
351 chain_map, link_map = self.fetch_chains(events)
352
353 expected_links = [
354 (bob_join, create),
355 (power, create),
356 (power, bob_join),
357 (alice_invite, create),
358 (alice_invite, power),
359 (alice_invite, bob_join),
360 ]
361
362 # Check that the expected links and only the expected links have been
363 # added.
364 self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
365
366 for start, end in expected_links:
367 start_id, start_seq = chain_map[start.event_id]
368 end_id, end_seq = chain_map[end.event_id]
369
370 self.assertIn(
371 (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
372 )
373
374 def persist(
375 self, events: List[EventBase],
376 ):
377 """Persist the given events and check that the links generated match
378 those given.
379 """
380
381 persist_events_store = self.hs.get_datastores().persist_events
382
383 for e in events:
384 e.internal_metadata.stream_ordering = self._next_stream_ordering
385 self._next_stream_ordering += 1
386
387 def _persist(txn):
388 # We need to persist the events to the events and state_events
389 # tables.
390 persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
391
392 # Actually call the function that calculates the auth chain stuff.
393 persist_events_store._persist_event_auth_chain_txn(txn, events)
394
395 self.get_success(
396 persist_events_store.db_pool.runInteraction("_persist", _persist,)
397 )
398
399 def fetch_chains(
400 self, events: List[EventBase]
401 ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
402
403 # Fetch the map from event ID -> (chain ID, sequence number)
404 rows = self.get_success(
405 self.store.db_pool.simple_select_many_batch(
406 table="event_auth_chains",
407 column="event_id",
408 iterable=[e.event_id for e in events],
409 retcols=("event_id", "chain_id", "sequence_number"),
410 keyvalues={},
411 )
412 )
413
414 chain_map = {
415 row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
416 }
417
418 # Fetch all the links and pass them to the _LinkMap.
419 rows = self.get_success(
420 self.store.db_pool.simple_select_many_batch(
421 table="event_auth_chain_links",
422 column="origin_chain_id",
423 iterable=[chain_id for chain_id, _ in chain_map.values()],
424 retcols=(
425 "origin_chain_id",
426 "origin_sequence_number",
427 "target_chain_id",
428 "target_sequence_number",
429 ),
430 keyvalues={},
431 )
432 )
433
434 link_map = _LinkMap()
435 for row in rows:
436 added = link_map.add_link(
437 (row["origin_chain_id"], row["origin_sequence_number"]),
438 (row["target_chain_id"], row["target_sequence_number"]),
439 )
440
441 # We shouldn't have persisted any redundant links
442 self.assertTrue(added)
443
444 return chain_map, link_map
445
446
447 class LinkMapTestCase(unittest.TestCase):
448 def test_simple(self):
449 """Basic tests for the LinkMap.
450 """
451 link_map = _LinkMap()
452
453 link_map.add_link((1, 1), (2, 1), new=False)
454 self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
455 self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
456 self.assertCountEqual(link_map.get_additions(), [])
457 self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
458 self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
459 self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
460 self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
461
462 # Attempting to add a redundant link is ignored.
463 self.assertFalse(link_map.add_link((1, 4), (2, 1)))
464 self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
465
466 # Adding new non-redundant links works
467 self.assertTrue(link_map.add_link((1, 3), (2, 3)))
468 self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
469
470 self.assertTrue(link_map.add_link((2, 5), (1, 3)))
471 self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
472 self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
473
474 self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
475
476
477 class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
478
479 servlets = [
480 admin.register_servlets,
481 room.register_servlets,
482 login.register_servlets,
483 ]
484
485 def prepare(self, reactor, clock, hs):
486 self.store = hs.get_datastore()
487 self.user_id = self.register_user("foo", "pass")
488 self.token = self.login("foo", "pass")
489 self.requester = create_requester(self.user_id)
490
491 def _generate_room(self) -> Tuple[str, List[Set[str]]]:
492 """Insert a room without a chain cover index.
493 """
494 room_id = self.helper.create_room_as(self.user_id, tok=self.token)
495
496 # Mark the room as not having a chain cover index
497 self.get_success(
498 self.store.db_pool.simple_update(
499 table="rooms",
500 keyvalues={"room_id": room_id},
501 updatevalues={"has_auth_chain_index": False},
502 desc="test",
503 )
504 )
505
506 # Create a fork in the DAG with different events.
507 event_handler = self.hs.get_event_creation_handler()
508 latest_event_ids = self.get_success(
509 self.store.get_prev_events_for_room(room_id)
510 )
511 event, context = self.get_success(
512 event_handler.create_event(
513 self.requester,
514 {
515 "type": "some_state_type",
516 "state_key": "",
517 "content": {},
518 "room_id": room_id,
519 "sender": self.user_id,
520 },
521 prev_event_ids=latest_event_ids,
522 )
523 )
524 self.get_success(
525 event_handler.handle_new_client_event(self.requester, event, context)
526 )
527 state1 = set(self.get_success(context.get_current_state_ids()).values())
528
529 event, context = self.get_success(
530 event_handler.create_event(
531 self.requester,
532 {
533 "type": "some_state_type",
534 "state_key": "",
535 "content": {},
536 "room_id": room_id,
537 "sender": self.user_id,
538 },
539 prev_event_ids=latest_event_ids,
540 )
541 )
542 self.get_success(
543 event_handler.handle_new_client_event(self.requester, event, context)
544 )
545 state2 = set(self.get_success(context.get_current_state_ids()).values())
546
547 # Delete the chain cover info.
548
549 def _delete_tables(txn):
550 txn.execute("DELETE FROM event_auth_chains")
551 txn.execute("DELETE FROM event_auth_chain_links")
552
553 self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
554
555 return room_id, [state1, state2]
556
557 def test_background_update_single_room(self):
558 """Test that the background update to calculate auth chains for historic
559 rooms works correctly.
560 """
561
562 # Create a room
563 room_id, states = self._generate_room()
564
565 # Insert and run the background update.
566 self.get_success(
567 self.store.db_pool.simple_insert(
568 "background_updates",
569 {"update_name": "chain_cover", "progress_json": "{}"},
570 )
571 )
572
573 # Ugh, have to reset this flag
574 self.store.db_pool.updates._all_done = False
575
576 while not self.get_success(
577 self.store.db_pool.updates.has_completed_background_updates()
578 ):
579 self.get_success(
580 self.store.db_pool.updates.do_next_background_update(100), by=0.1
581 )
582
583 # Test that the `has_auth_chain_index` has been set
584 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
585
586 # Test that calculating the auth chain difference using the newly
587 # calculated chain cover works.
588 self.get_success(
589 self.store.db_pool.runInteraction(
590 "test",
591 self.store._get_auth_chain_difference_using_cover_index_txn,
592 room_id,
593 states,
594 )
595 )
596
597 def test_background_update_multiple_rooms(self):
598 """Test that the background update to calculate auth chains for historic
599 rooms works correctly.
600 """
601 # Create a room
602 room_id1, states1 = self._generate_room()
603 room_id2, states2 = self._generate_room()
604 room_id3, states3 = self._generate_room()
605
606 # Insert and run the background update.
607 self.get_success(
608 self.store.db_pool.simple_insert(
609 "background_updates",
610 {"update_name": "chain_cover", "progress_json": "{}"},
611 )
612 )
613
614 # Ugh, have to reset this flag
615 self.store.db_pool.updates._all_done = False
616
617 while not self.get_success(
618 self.store.db_pool.updates.has_completed_background_updates()
619 ):
620 self.get_success(
621 self.store.db_pool.updates.do_next_background_update(100), by=0.1
622 )
623
624 # Test that the `has_auth_chain_index` has been set
625 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
626 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
627 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
628
629 # Test that calculating the auth chain difference using the newly
630 # calculated chain cover works.
631 self.get_success(
632 self.store.db_pool.runInteraction(
633 "test",
634 self.store._get_auth_chain_difference_using_cover_index_txn,
635 room_id1,
636 states1,
637 )
638 )
639
640 def test_background_update_single_large_room(self):
641 """Test that the background update to calculate auth chains for historic
642 rooms works correctly.
643 """
644
645 # Create a room
646 room_id, states = self._generate_room()
647
648 # Add a bunch of state so that it takes multiple iterations of the
649 # background update to process the room.
650 for i in range(0, 150):
651 self.helper.send_state(
652 room_id, event_type="m.test", body={"index": i}, tok=self.token
653 )
654
655 # Insert and run the background update.
656 self.get_success(
657 self.store.db_pool.simple_insert(
658 "background_updates",
659 {"update_name": "chain_cover", "progress_json": "{}"},
660 )
661 )
662
663 # Ugh, have to reset this flag
664 self.store.db_pool.updates._all_done = False
665
666 iterations = 0
667 while not self.get_success(
668 self.store.db_pool.updates.has_completed_background_updates()
669 ):
670 iterations += 1
671 self.get_success(
672 self.store.db_pool.updates.do_next_background_update(100), by=0.1
673 )
674
675 # Ensure that we did actually take multiple iterations to process the
676 # room.
677 self.assertGreater(iterations, 1)
678
679 # Test that the `has_auth_chain_index` has been set
680 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
681
682 # Test that calculating the auth chain difference using the newly
683 # calculated chain cover works.
684 self.get_success(
685 self.store.db_pool.runInteraction(
686 "test",
687 self.store._get_auth_chain_difference_using_cover_index_txn,
688 room_id,
689 states,
690 )
691 )
692
693 def test_background_update_multiple_large_room(self):
694 """Test that the background update to calculate auth chains for historic
695 rooms works correctly.
696 """
697
698 # Create the rooms
699 room_id1, _ = self._generate_room()
700 room_id2, _ = self._generate_room()
701
702 # Add a bunch of state so that it takes multiple iterations of the
703 # background update to process the room.
704 for i in range(0, 150):
705 self.helper.send_state(
706 room_id1, event_type="m.test", body={"index": i}, tok=self.token
707 )
708
709 for i in range(0, 150):
710 self.helper.send_state(
711 room_id2, event_type="m.test", body={"index": i}, tok=self.token
712 )
713
714 # Insert and run the background update.
715 self.get_success(
716 self.store.db_pool.simple_insert(
717 "background_updates",
718 {"update_name": "chain_cover", "progress_json": "{}"},
719 )
720 )
721
722 # Ugh, have to reset this flag
723 self.store.db_pool.updates._all_done = False
724
725 iterations = 0
726 while not self.get_success(
727 self.store.db_pool.updates.has_completed_background_updates()
728 ):
729 iterations += 1
730 self.get_success(
731 self.store.db_pool.updates.do_next_background_update(100), by=0.1
732 )
733
734 # Ensure that we did actually take multiple iterations to process the
735 # room.
736 self.assertGreater(iterations, 1)
737
738 # Test that the `has_auth_chain_index` has been set
739 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
740 self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import attr
16 from parameterized import parameterized
17
18 from synapse.events import _EventInternalMetadata
19
1520 import tests.unittest
1621 import tests.utils
1722
112117 r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
113118 self.assertTrue(r == [room2] or r == [room3])
114119
115 def test_auth_difference(self):
120 @parameterized.expand([(True,), (False,)])
121 def test_auth_difference(self, use_chain_cover_index: bool):
116122 room_id = "@ROOM:local"
117123
118124 # The silly auth graph we use to test the auth difference algorithm,
158164 "j": 1,
159165 }
160166
167 # Set whether the room has a chain cover index, per the test parameter.
168
169 def store_room(txn):
170 self.store.db_pool.simple_insert_txn(
171 txn,
172 "rooms",
173 {
174 "room_id": room_id,
175 "creator": "room_creator_user_id",
176 "is_public": True,
177 "room_version": "6",
178 "has_auth_chain_index": use_chain_cover_index,
179 },
180 )
181
182 self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
183
161184 # We rudely fiddle with the appropriate tables directly, as that's much
162185 # easier than constructing events properly.
163186
164 def insert_event(txn, event_id, stream_ordering):
165
166 depth = depth_map[event_id]
167
187 def insert_event(txn):
188 stream_ordering = 0
189
190 for event_id in auth_graph:
191 stream_ordering += 1
192 depth = depth_map[event_id]
193
194 self.store.db_pool.simple_insert_txn(
195 txn,
196 table="events",
197 values={
198 "event_id": event_id,
199 "room_id": room_id,
200 "depth": depth,
201 "topological_ordering": depth,
202 "type": "m.test",
203 "processed": True,
204 "outlier": False,
205 "stream_ordering": stream_ordering,
206 },
207 )
208
209 self.hs.datastores.persist_events._persist_event_auth_chain_txn(
210 txn,
211 [
212 FakeEvent(event_id, room_id, auth_graph[event_id])
213 for event_id in auth_graph
214 ],
215 )
216
217 self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
218
219 # Now actually test that various combinations give the right result:
220
221 difference = self.get_success(
222 self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
223 )
224 self.assertSetEqual(difference, {"a", "b"})
225
226 difference = self.get_success(
227 self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
228 )
229 self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
230
231 difference = self.get_success(
232 self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
233 )
234 self.assertSetEqual(difference, {"a", "b", "c"})
235
236 difference = self.get_success(
237 self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
238 )
239 self.assertSetEqual(difference, {"a", "b"})
240
241 difference = self.get_success(
242 self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
243 )
244 self.assertSetEqual(difference, {"a", "b", "d", "e"})
245
246 difference = self.get_success(
247 self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
248 )
249 self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
250
251 difference = self.get_success(
252 self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
253 )
254 self.assertSetEqual(difference, {"a", "b"})
255
256 difference = self.get_success(
257 self.store.get_auth_chain_difference(room_id, [{"a"}])
258 )
259 self.assertSetEqual(difference, set())
260
261 def test_auth_difference_partial_cover(self):
262 """Test that we correctly handle rooms where not all events have a chain
263 cover calculated. This can happen in some obscure edge cases, including
264 during the background update that calculates the chain cover for old
265 rooms.
266 """
267
268 room_id = "@ROOM:local"
269
270 # The silly auth graph we use to test the auth difference algorithm,
271 # where the top are the most recent events.
272 #
273 # A B
274 # \ /
275 # D E
276 # \ |
277 # ` F C
278 # | /|
279 # G ´ |
280 # | \ |
281 # H I
282 # | |
283 # K J
284
285 auth_graph = {
286 "a": ["e"],
287 "b": ["e"],
288 "c": ["g", "i"],
289 "d": ["f"],
290 "e": ["f"],
291 "f": ["g"],
292 "g": ["h", "i"],
293 "h": ["k"],
294 "i": ["j"],
295 "k": [],
296 "j": [],
297 }
298
299 depth_map = {
300 "a": 7,
301 "b": 7,
302 "c": 4,
303 "d": 6,
304 "e": 6,
305 "f": 5,
306 "g": 3,
307 "h": 2,
308 "i": 2,
309 "k": 1,
310 "j": 1,
311 }
312
313 # We rudely fiddle with the appropriate tables directly, as that's much
314 # easier than constructing events properly.
315
316 def insert_event(txn):
317 # First insert the room and mark it as having a chain cover.
168318 self.store.db_pool.simple_insert_txn(
169319 txn,
170 table="events",
171 values={
172 "event_id": event_id,
320 "rooms",
321 {
173322 "room_id": room_id,
174 "depth": depth,
175 "topological_ordering": depth,
176 "type": "m.test",
177 "processed": True,
178 "outlier": False,
179 "stream_ordering": stream_ordering,
323 "creator": "room_creator_user_id",
324 "is_public": True,
325 "room_version": "6",
326 "has_auth_chain_index": True,
180327 },
181328 )
182329
183 self.store.db_pool.simple_insert_many_txn(
184 txn,
185 table="event_auth",
186 values=[
187 {"event_id": event_id, "room_id": room_id, "auth_id": a}
188 for a in auth_graph[event_id]
330 stream_ordering = 0
331
332 for event_id in auth_graph:
333 stream_ordering += 1
334 depth = depth_map[event_id]
335
336 self.store.db_pool.simple_insert_txn(
337 txn,
338 table="events",
339 values={
340 "event_id": event_id,
341 "room_id": room_id,
342 "depth": depth,
343 "topological_ordering": depth,
344 "type": "m.test",
345 "processed": True,
346 "outlier": False,
347 "stream_ordering": stream_ordering,
348 },
349 )
350
351 # Insert all events apart from 'B'
352 self.hs.datastores.persist_events._persist_event_auth_chain_txn(
353 txn,
354 [
355 FakeEvent(event_id, room_id, auth_graph[event_id])
356 for event_id in auth_graph
357 if event_id != "b"
189358 ],
190359 )
191360
192 next_stream_ordering = 0
193 for event_id in auth_graph:
194 next_stream_ordering += 1
195 self.get_success(
196 self.store.db_pool.runInteraction(
197 "insert", insert_event, event_id, next_stream_ordering
198 )
199 )
361 # Now we insert the event 'B' without a chain cover, by temporarily
362 # pretending the room doesn't have a chain cover.
363
364 self.store.db_pool.simple_update_txn(
365 txn,
366 table="rooms",
367 keyvalues={"room_id": room_id},
368 updatevalues={"has_auth_chain_index": False},
369 )
370
371 self.hs.datastores.persist_events._persist_event_auth_chain_txn(
372 txn, [FakeEvent("b", room_id, auth_graph["b"])],
373 )
374
375 self.store.db_pool.simple_update_txn(
376 txn,
377 table="rooms",
378 keyvalues={"room_id": room_id},
379 updatevalues={"has_auth_chain_index": True},
380 )
381
382 self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
200383
201384 # Now actually test that various combinations give the right result:
202385
239422 self.store.get_auth_chain_difference(room_id, [{"a"}])
240423 )
241424 self.assertSetEqual(difference, set())
425
426
427 @attr.s
428 class FakeEvent:
429 event_id = attr.ib()
430 room_id = attr.ib()
431 auth_events = attr.ib()
432
433 type = "foo"
434 state_key = "foo"
435
436 internal_metadata = _EventInternalMetadata({})
437
438 def auth_event_ids(self):
439 return self.auth_events
440
441 def is_state(self):
442 return True
5050 self.db_pool,
5151 stream_name="test_stream",
5252 instance_name=instance_name,
53 table="foobar",
54 instance_column="instance_name",
55 id_column="stream_id",
53 tables=[("foobar", "instance_name", "stream_id")],
5654 sequence_name="foobar_seq",
5755 writers=writers,
5856 )
486484 self.db_pool,
487485 stream_name="test_stream",
488486 instance_name=instance_name,
489 table="foobar",
490 instance_column="instance_name",
491 id_column="stream_id",
487 tables=[("foobar", "instance_name", "stream_id")],
492488 sequence_name="foobar_seq",
493489 writers=writers,
494490 positive=False,
578574 self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
579575 self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
580576 self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
577
578
579 class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
580 if not USE_POSTGRES_FOR_TESTS:
581 skip = "Requires Postgres"
582
583 def prepare(self, reactor, clock, hs):
584 self.store = hs.get_datastore()
585 self.db_pool = self.store.db_pool # type: DatabasePool
586
587 self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
588
589 def _setup_db(self, txn):
590 txn.execute("CREATE SEQUENCE foobar_seq")
591 txn.execute(
592 """
593 CREATE TABLE foobar1 (
594 stream_id BIGINT NOT NULL,
595 instance_name TEXT NOT NULL,
596 data TEXT
597 );
598 """
599 )
600
601 txn.execute(
602 """
603 CREATE TABLE foobar2 (
604 stream_id BIGINT NOT NULL,
605 instance_name TEXT NOT NULL,
606 data TEXT
607 );
608 """
609 )
610
611 def _create_id_generator(
612 self, instance_name="master", writers=["master"]
613 ) -> MultiWriterIdGenerator:
614 def _create(conn):
615 return MultiWriterIdGenerator(
616 conn,
617 self.db_pool,
618 stream_name="test_stream",
619 instance_name=instance_name,
620 tables=[
621 ("foobar1", "instance_name", "stream_id"),
622 ("foobar2", "instance_name", "stream_id"),
623 ],
624 sequence_name="foobar_seq",
625 writers=writers,
626 )
627
628 return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
629
630 def _insert_rows(
631 self,
632 table: str,
633 instance_name: str,
634 number: int,
635 update_stream_table: bool = True,
636 ):
637 """Insert N rows as the given instance, inserting with stream IDs pulled
638 from the postgres sequence.
639 """
640
641 def _insert(txn):
642 for _ in range(number):
643 txn.execute(
644 "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
645 (instance_name,),
646 )
647 if update_stream_table:
648 txn.execute(
649 """
650 INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
651 ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
652 """,
653 (instance_name,),
654 )
655
656 self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
657
658 def test_load_existing_stream(self):
659 """Test creating ID gens with multiple tables that have rows from after
660 the position in the `stream_positions` table.
661 """
662 self._insert_rows("foobar1", "first", 3)
663 self._insert_rows("foobar2", "second", 3)
664 self._insert_rows("foobar2", "second", 1, update_stream_table=False)
665
666 first_id_gen = self._create_id_generator("first", writers=["first", "second"])
667 second_id_gen = self._create_id_generator("second", writers=["first", "second"])
668
669 # The first ID gen will notice that it can advance its token to 7 as it
670 # has no in progress writes...
671 self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
672 self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
673 self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
674 self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
675
676 # ... but the second ID gen doesn't know that.
677 self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
678 self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
679 self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
680 self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
4747 ),
4848 )
4949
50 # test set to None
51 yield defer.ensureDeferred(
52 self.store.set_profile_displayname(self.u_frank.localpart, None)
53 )
54
55 self.assertIsNone(
56 (
57 yield defer.ensureDeferred(
58 self.store.get_profile_displayname(self.u_frank.localpart)
59 )
60 )
61 )
62
5063 @defer.inlineCallbacks
5164 def test_avatar_url(self):
5265 yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
6578 )
6679 ),
6780 )
81
82 # test set to None
83 yield defer.ensureDeferred(
84 self.store.set_profile_avatar_url(self.u_frank.localpart, None)
85 )
86
87 self.assertIsNone(
88 (
89 yield defer.ensureDeferred(
90 self.store.get_profile_avatar_url(self.u_frank.localpart)
91 )
92 )
93 )
1919
2020 from . import unittest
2121
22 try:
23 import lxml
24 except ImportError:
25 lxml = None
26
2227
2328 class PreviewTestCase(unittest.TestCase):
29 if not lxml:
30 skip = "url preview feature requires lxml"
31
2432 def test_long_summarize(self):
2533 example_paras = [
2634 """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
136144
137145
138146 class PreviewUrlTestCase(unittest.TestCase):
147 if not lxml:
148 skip = "url preview feature requires lxml"
149
139150 def test_simple(self):
140151 html = """
141152 <html>
5757
5858 self.assertEquals(room.to_string(), "#channel:my.domain")
5959
60 def test_validate(self):
61 id_string = "#test:domain,test"
62 self.assertFalse(RoomAlias.is_valid(id_string))
63
6064
6165 class GroupIDTestCase(unittest.TestCase):
6266 def test_parse(self):
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from html.parser import HTMLParser
16 from typing import Dict, Iterable, List, Optional, Tuple
17
18
19 class TestHtmlParser(HTMLParser):
20 """A generic HTML page parser which extracts useful things from the HTML"""
21
22 def __init__(self):
23 super().__init__()
24
25 # a list of links found in the doc
26 self.links = [] # type: List[str]
27
28 # the values of any hidden <input>s: map from name to value
29 self.hiddens = {} # type: Dict[str, Optional[str]]
30
31 # the values of any radio buttons: map from name to list of values
32 self.radios = {} # type: Dict[str, List[Optional[str]]]
33
34 def handle_starttag(
35 self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
36 ) -> None:
37 attr_dict = dict(attrs)
38 if tag == "a":
39 href = attr_dict["href"]
40 if href:
41 self.links.append(href)
42 elif tag == "input":
43 input_name = attr_dict.get("name")
44 if attr_dict["type"] == "radio":
45 assert input_name
46 self.radios.setdefault(input_name, []).append(attr_dict["value"])
47 elif attr_dict["type"] == "hidden":
48 assert input_name
49 self.hiddens[input_name] = attr_dict["value"]
50
51 def error(_, message):
52 raise AssertionError(message)
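A short usage sketch for the parser defined above (the HTML fragment is arbitrary and not taken from the diff): feed it a page, then read back the extracted links and hidden-input values.

parser = TestHtmlParser()
parser.feed('<a href="/next">next</a><input type="hidden" name="token" value="abc">')
parser.close()
assert parser.links == ["/next"]
assert parser.hiddens == {"token": "abc"}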
1919 import inspect
2020 import logging
2121 import time
22 from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
22 from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
2323
2424 from mock import Mock, patch
2525
735735 return func
736736
737737 return decorator
738
739
740 TV = TypeVar("TV")
741
742
743 def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
744 """A test decorator which will skip the decorated test unless a condition is set
745
746 For example:
747
748 class MyTestCase(TestCase):
749 @skip_unless(HAS_FOO, "Cannot test without foo")
750 def test_foo(self):
751 ...
752
753 Args:
754 condition: If true, the test will be skipped
755 reason: the reason to give for skipping the test
756 """
757
758 def decorator(f: TV) -> TV:
759 if not condition:
760 f.skip = reason # type: ignore
761 return f
762
763 return decorator
2424 class DeferredCacheTestCase(TestCase):
2525 def test_empty(self):
2626 cache = DeferredCache("test")
27 failed = False
28 try:
27 with self.assertRaises(KeyError):
2928 cache.get("foo")
30 except KeyError:
31 failed = True
32
33 self.assertTrue(failed)
3429
3530 def test_hit(self):
3631 cache = DeferredCache("test")
154149 cache.prefill(("foo",), 123)
155150 cache.invalidate(("foo",))
156151
157 failed = False
158 try:
152 with self.assertRaises(KeyError):
159153 cache.get(("foo",))
160 except KeyError:
161 failed = True
162
163 self.assertTrue(failed)
164154
165155 def test_invalidate_all(self):
166156 cache = DeferredCache("testcache")
214204 cache.prefill(2, "two")
215205 cache.prefill(3, "three") # 1 will be evicted
216206
217 failed = False
218 try:
207 with self.assertRaises(KeyError):
219208 cache.get(1)
220 except KeyError:
221 failed = True
222
223 self.assertTrue(failed)
224209
225210 cache.get(2)
226211 cache.get(3)
238223
239224 cache.prefill(3, "three")
240225
241 failed = False
242 try:
226 with self.assertRaises(KeyError):
243227 cache.get(2)
244 except KeyError:
245 failed = True
246
247 self.assertTrue(failed)
248228
249229 cache.get(1)
250230 cache.get(3)
231
232 def test_eviction_iterable(self):
233 cache = DeferredCache(
234 "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
235 )
236
237 cache.prefill(1, ["one", "two"])
238 cache.prefill(2, ["three"])
239
240 # Now access 1 again, thus causing 2 to be least-recently used
241 cache.get(1)
242
243 # Now add an item to the cache, which evicts 2.
244 cache.prefill(3, ["four"])
245 with self.assertRaises(KeyError):
246 cache.get(2)
247
248 # Ensure 1 & 3 are in the cache.
249 cache.get(1)
250 cache.get(3)
251
252 # Now access 1 again, thus causing 3 to be least-recently used
253 cache.get(1)
254
255 # Now add an item with multiple elements to the cache
256 cache.prefill(4, ["five", "six"])
257
258 # Both 1 and 3 are evicted since there's too many elements.
259 with self.assertRaises(KeyError):
260 cache.get(1)
261 with self.assertRaises(KeyError):
262 cache.get(3)
263
264 # Now add another item to fill the cache again.
265 cache.prefill(5, ["seven"])
266
267 # Now access 4, thus causing 5 to be least-recently used
268 cache.get(4)
269
270 # Add an empty item.
271 cache.prefill(6, [])
272
273 # 5 gets evicted and replaced since an empty element counts as an item.
274 with self.assertRaises(KeyError):
275 cache.get(5)
276 cache.get(4)
277 cache.get(6)
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from synapse.util.iterutils import chunk_seq
14 from typing import Dict, List
15
16 from synapse.util.iterutils import chunk_seq, sorted_topologically
1517
1618 from tests.unittest import TestCase
1719
4446 self.assertEqual(
4547 list(parts), [],
4648 )
49
50
51 class SortTopologically(TestCase):
52 def test_empty(self):
53 "Test that an empty graph works correctly"
54
55 graph = {} # type: Dict[int, List[int]]
56 self.assertEqual(list(sorted_topologically([], graph)), [])
57
58 def test_handle_empty_graph(self):
59 "Test that a graph where a node doesn't have an entry is treated as empty"
60
61 graph = {} # type: Dict[int, List[int]]
62
63 # For disconnected nodes the output is simply sorted.
64 self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
65
66 def test_disconnected(self):
67 "Test that a graph with no edges work"
68
69 graph = {1: [], 2: []} # type: Dict[int, List[int]]
70
71 # For disconnected nodes the output is simply sorted.
72 self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
73
74 def test_linear(self):
75 "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
76
77 graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
78
79 self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
80
81 def test_subset(self):
82 "Test that only sorting a subset of the graph works"
83 graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
84
85 self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
86
87 def test_fork(self):
88 "Test that a forked graph works"
89 graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
90
91 # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
92 # always get the same one.
93 self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
94
95 def test_duplicates(self):
96 "Test that a graph with duplicate edges work"
97 graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]
98
99 self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
100
101 def test_multiple_paths(self):
102 "Test that a graph with multiple paths between two nodes work"
103 graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]
104
105 self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
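A minimal sketch of the behaviour these tests exercise, assuming Kahn's algorithm with a heap to make the output deterministic; `sorted_topologically_sketch` is an illustrative stand-in, not the actual `synapse.util.iterutils.sorted_topologically` implementation.

import heapq
from typing import Dict, Iterable, Iterator, List

def sorted_topologically_sketch(
    nodes: Iterable[int], graph: Dict[int, List[int]]
) -> Iterator[int]:
    """Yield `nodes` so that every node comes after its in-set dependencies."""
    node_set = set(nodes)
    # Count not-yet-emitted dependencies per node; edges to nodes outside
    # `nodes` are ignored and duplicate edges are collapsed.
    degree = {n: len(set(graph.get(n, ())) & node_set) for n in node_set}
    # Reverse edges: dependency -> nodes (within the set) that depend on it.
    dependents = {}  # type: Dict[int, List[int]]
    for n in node_set:
        for dep in set(graph.get(n, ())) & node_set:
            dependents.setdefault(dep, []).append(n)
    heap = [n for n, d in degree.items() if d == 0]
    heapq.heapify(heap)
    while heap:
        n = heapq.heappop(heap)
        yield n
        for m in dependents.get(n, ()):
            degree[m] -= 1
            if degree[m] == 0:
                heapq.heappush(heap, m)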
158158 "remote": {"per_second": 10000, "burst_count": 10000},
159159 },
160160 "saml2_enabled": False,
161 "public_baseurl": None,
162161 "default_identity_server": None,
163162 "key_refresh_interval": 24 * 60 * 60 * 1000,
164163 "old_signing_keys": {},
11 envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort
22
33 [base]
4 extras = test
54 deps =
65 python-subunit
76 junitxml
2423 # install the "enum34" dependency of cryptography.
2524 pip>=10
2625
26 # directories/files we run the linters on
27 lint_targets =
28 setup.py
29 synapse
30 tests
31 scripts
32 scripts-dev
33 stubs
34 contrib
35 synctl
36 synmark
37 .buildkite
38 docker
39
40 # default settings for all tox environments
2741 [testenv]
2842 deps =
2943 {[base]deps}
30 extras = all, test
44 extras =
45 # install the optional dependencies for tox environments without
46 # '-noextras' in their name
47 !noextras: all
48 test
3149
3250 setenv =
3351 # use a postgres db for tox environments with "-postgres" in the name
84102 [testenv:py35-old]
85103 skip_install=True
86104 deps =
105 # Ensure a version of setuptools that supports Python 3.5 is installed.
106 setuptools < 51.0.0
107
87108 # Old automat version for Twisted
88109 Automat == 0.3.0
89110
95116 # Make all greater-thans equals so we test the oldest version of our direct
96117 # dependencies, but make the pyopenssl 17.0, which can work against an
97118 # OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis).
98 /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
119 /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
99120
100121 # Install Synapse itself. This won't update any libraries.
101122 pip install -e ".[test]"
125146 [testenv:check_codestyle]
126147 extras = lint
127148 commands =
128 python -m black --check --diff .
129 /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
149 python -m black --check --diff {[base]lint_targets}
150 flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
130151 {toxinidir}/scripts-dev/config-lint.sh
131152
132153 [testenv:check_isort]
133154 extras = lint
134 commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts"
155 commands = isort -c --df --sp setup.cfg {[base]lint_targets}
135156
136157 [testenv:check-newsfragment]
137158 skip_install = True