Codebase list matrix-synapse / e894f36
New upstream version 1.49.0 Andrej Shadura 2 years ago
165 changed file(s) with 7923 addition(s) and 2887 deletion(s).
373373 working-directory: complement/dockerfiles
374374
375375 # Run Complement
376 - run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests/...
376 - run: go test -v -tags synapse_blacklist,msc2403 ./tests/...
377377 env:
378378 COMPLEMENT_BASE_IMAGE: complement-synapse:latest
379379 working-directory: complement
0 Synapse 1.49.0 (2021-12-14)
1 ===========================
2
3 No significant changes since version 1.49.0rc1.
4
5
6 Support for Ubuntu 21.04 ends next month on the 20th of January
7 ---------------------------------------------------------------
8
9 For users of Ubuntu 21.04 (Hirsute Hippo), please be aware that [upstream support for this version of Ubuntu will end next month][Ubuntu2104EOL].
10 We will stop producing packages for Ubuntu 21.04 after upstream support ends.
11
12 [Ubuntu2104EOL]: https://lists.ubuntu.com/archives/ubuntu-announce/2021-December/000275.html
13
14
15 The wiki has been migrated to the documentation website
16 -------------------------------------------------------
17
18 We've decided to move the existing, somewhat stagnant pages from the GitHub wiki
19 to the [documentation website](https://matrix-org.github.io/synapse/latest/).
20
21 This was done for two reasons. The first was to ensure that changes are checked by
22 multiple authors before being committed (everyone makes mistakes!), and the second
23 was the visibility of the documentation. Not everyone knows that Synapse has some very
24 useful information hidden away in its GitHub wiki pages. Bringing them to the
25 documentation website should help with visibility, as well as keep all Synapse
26 documentation in one easily-searchable location.
27
28 Note that contributions to the documentation website happen through [GitHub pull
29 requests](https://github.com/matrix-org/synapse/pulls). Please visit [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
30 if you need help with the process!
31
32
33 Synapse 1.49.0rc1 (2021-12-07)
34 ==============================
35
36 Features
37 --------
38
39 - Add [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) experimental client and federation API endpoints to get the closest event to a given timestamp. ([\#9445](https://github.com/matrix-org/synapse/issues/9445))
40 - Include bundled relation aggregations during a limited `/sync` request and `/relations` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11284](https://github.com/matrix-org/synapse/issues/11284), [\#11478](https://github.com/matrix-org/synapse/issues/11478))
41 - Add plugin support for controlling database background updates. ([\#11306](https://github.com/matrix-org/synapse/issues/11306), [\#11475](https://github.com/matrix-org/synapse/issues/11475), [\#11479](https://github.com/matrix-org/synapse/issues/11479))
42 - Support the stable API endpoints for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): the room `/hierarchy` endpoint. ([\#11329](https://github.com/matrix-org/synapse/issues/11329))
43 - Add admin API to get some information about federation status with remote servers. ([\#11407](https://github.com/matrix-org/synapse/issues/11407))
44 - Support expiry of refresh tokens and expiry of the overall session when refresh tokens are in use. ([\#11425](https://github.com/matrix-org/synapse/issues/11425))
45 - Stabilise support for [MSC2918](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) refresh tokens as they have now been merged into the Matrix specification. ([\#11435](https://github.com/matrix-org/synapse/issues/11435), [\#11522](https://github.com/matrix-org/synapse/issues/11522))
46 - Update [MSC2918 refresh token](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) support to conform to the latest revision: accept the `refresh_tokens` parameter in the request body rather than in the URL parameters. ([\#11430](https://github.com/matrix-org/synapse/issues/11430))
47 - Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. ([\#11445](https://github.com/matrix-org/synapse/issues/11445))
48 - Expose `synapse_homeserver` and `synapse_worker` commands as entry points to run Synapse's main process and worker processes, respectively. Contributed by @Ma27. ([\#11449](https://github.com/matrix-org/synapse/issues/11449))
49 - `synctl stop` will now wait for Synapse to exit before returning. ([\#11459](https://github.com/matrix-org/synapse/issues/11459), [\#11490](https://github.com/matrix-org/synapse/issues/11490))
50 - Extend the "delete room" admin API to work correctly on rooms which have previously been partially deleted. ([\#11523](https://github.com/matrix-org/synapse/issues/11523))
51 - Add support for the `/_matrix/client/v3/login/sso/redirect/{idpId}` API from Matrix v1.1. This endpoint was overlooked when support for v3 endpoints was added in Synapse 1.48.0rc1. ([\#11451](https://github.com/matrix-org/synapse/issues/11451))
52
53
54 Bugfixes
55 --------
56
57 - Fix using [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) batch sending in combination with event persistence workers. Contributed by @tulir at Beeper. ([\#11220](https://github.com/matrix-org/synapse/issues/11220))
58 - Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, properly this time. Also fix a race condition introduced in the previous insufficient fix in Synapse 1.47.0. ([\#11376](https://github.com/matrix-org/synapse/issues/11376))
59 - The `/send_join` response now includes the stable `event` field instead of the unstable field from [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083). ([\#11413](https://github.com/matrix-org/synapse/issues/11413))
60 - Fix a bug introduced in Synapse 1.47.0 where `send_join` could fail due to an outdated `ijson` version. ([\#11439](https://github.com/matrix-org/synapse/issues/11439), [\#11441](https://github.com/matrix-org/synapse/issues/11441), [\#11460](https://github.com/matrix-org/synapse/issues/11460))
61 - Fix a bug introduced in Synapse 1.36.0 which could cause problems fetching event-signing keys from trusted key servers. ([\#11440](https://github.com/matrix-org/synapse/issues/11440))
62 - Fix a bug introduced in Synapse 1.47.1 where the media repository would fail to work if the media store path contained any symbolic links. ([\#11446](https://github.com/matrix-org/synapse/issues/11446))
63 - Fix an `LruCache` corruption bug, introduced in Synapse 1.38.0, that would cause certain requests to fail until the next Synapse restart. ([\#11454](https://github.com/matrix-org/synapse/issues/11454))
64 - Fix a long-standing bug where invites from ignored users were included in incremental syncs. ([\#11511](https://github.com/matrix-org/synapse/issues/11511))
65 - Fix a regression in Synapse 1.48.0 where presence workers would not clear their presence updates over replication on shutdown. ([\#11518](https://github.com/matrix-org/synapse/issues/11518))
66 - Fix a regression in Synapse 1.48.0 where the module API's `looping_background_call` method would spam errors to the logs when given a non-async function. ([\#11524](https://github.com/matrix-org/synapse/issues/11524))
67
68
69 Updates to the Docker image
70 ---------------------------
71
72 - Update `Dockerfile-workers` to healthcheck all workers in the container. ([\#11429](https://github.com/matrix-org/synapse/issues/11429))
73
74
75 Improved Documentation
76 ----------------------
77
78 - Update the media repository documentation. ([\#11415](https://github.com/matrix-org/synapse/issues/11415))
79 - Update section about backward extremities in the room DAG concepts doc to correct the misconception about backward extremities indicating whether we have fetched an event's `prev_events`. ([\#11469](https://github.com/matrix-org/synapse/issues/11469))
80
81
82 Internal Changes
83 ----------------
84
85 - Add `Final` annotation to string constants in `synapse.api.constants` so that they get typed as `Literal`s. ([\#11356](https://github.com/matrix-org/synapse/issues/11356))
86 - Add a check to ensure that users cannot start the Synapse master process when `worker_app` is set. ([\#11416](https://github.com/matrix-org/synapse/issues/11416))
87 - Add a note about postgres memory management and hugepages to postgres doc. ([\#11467](https://github.com/matrix-org/synapse/issues/11467))
88 - Add missing type hints to `synapse.config` module. ([\#11465](https://github.com/matrix-org/synapse/issues/11465))
89 - Add missing type hints to `synapse.federation`. ([\#11483](https://github.com/matrix-org/synapse/issues/11483))
90 - Add type annotations to `tests.storage.test_appservice`. ([\#11488](https://github.com/matrix-org/synapse/issues/11488), [\#11492](https://github.com/matrix-org/synapse/issues/11492))
91 - Add type annotations to some of the configuration surrounding refresh tokens. ([\#11428](https://github.com/matrix-org/synapse/issues/11428))
92 - Add type hints to `synapse/tests/rest/admin`. ([\#11501](https://github.com/matrix-org/synapse/issues/11501))
93 - Add type hints to storage classes. ([\#11411](https://github.com/matrix-org/synapse/issues/11411))
94 - Add wiki pages to documentation website. ([\#11402](https://github.com/matrix-org/synapse/issues/11402))
95 - Clean up `tests.storage.test_main` to remove use of legacy code. ([\#11493](https://github.com/matrix-org/synapse/issues/11493))
96 - Clean up `tests.test_visibility` to remove legacy code. ([\#11495](https://github.com/matrix-org/synapse/issues/11495))
97 - Convert status codes to `HTTPStatus` in `synapse.rest.admin`. ([\#11452](https://github.com/matrix-org/synapse/issues/11452), [\#11455](https://github.com/matrix-org/synapse/issues/11455))
98 - Extend the `scripts-dev/sign_json` script to support signing events. ([\#11486](https://github.com/matrix-org/synapse/issues/11486))
99 - Improve internal types in push code. ([\#11409](https://github.com/matrix-org/synapse/issues/11409))
100 - Improve type annotations in `synapse.module_api`. ([\#11029](https://github.com/matrix-org/synapse/issues/11029))
101 - Improve type hints for `LruCache`. ([\#11453](https://github.com/matrix-org/synapse/issues/11453))
102 - Preparation for database schema simplifications: disambiguate queries on `state_key`. ([\#11497](https://github.com/matrix-org/synapse/issues/11497))
103 - Refactor `backfilled` into specific behavior function arguments (`_persist_events_and_state_updates` and downstream calls). ([\#11417](https://github.com/matrix-org/synapse/issues/11417))
104 - Refactor `get_version_string` to fix-up types and duplicated code. ([\#11468](https://github.com/matrix-org/synapse/issues/11468))
105 - Refactor various parts of the `/sync` handler. ([\#11494](https://github.com/matrix-org/synapse/issues/11494), [\#11515](https://github.com/matrix-org/synapse/issues/11515))
106 - Remove unnecessary `json.dumps` from `tests.rest.admin`. ([\#11461](https://github.com/matrix-org/synapse/issues/11461))
107 - Save the OpenID Connect session ID on login. ([\#11482](https://github.com/matrix-org/synapse/issues/11482))
108 - Update and clean up recently ported documentation pages. ([\#11466](https://github.com/matrix-org/synapse/issues/11466))
109
110
0111 Synapse 1.48.0 (2021-11-30)
1112 ===========================
2113
0 matrix-synapse-py3 (1.49.0) stable; urgency=medium
1
2 * New synapse release 1.49.0.
3
4 -- Synapse Packaging team <packages@matrix.org> Tue, 14 Dec 2021 12:39:46 +0000
5
6 matrix-synapse-py3 (1.49.0~rc1) stable; urgency=medium
7
8 * New synapse release 1.49.0~rc1.
9
10 -- Synapse Packaging team <packages@matrix.org> Tue, 07 Dec 2021 13:52:21 +0000
11
012 matrix-synapse-py3 (1.48.0) stable; urgency=medium
113
214 * New synapse release 1.48.0.
2020 # files to run the desired worker configuration. Will start supervisord.
2121 COPY ./docker/configure_workers_and_start.py /configure_workers_and_start.py
2222 ENTRYPOINT ["/configure_workers_and_start.py"]
23
24 HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
25 CMD /bin/sh /healthcheck.sh
0 #!/bin/sh
1 # This healthcheck script is designed to return OK when every
2 # host involved returns OK
3 {%- for healthcheck_url in healthcheck_urls %}
4 curl -fSs {{ healthcheck_url }} || exit 1
5 {%- endfor %}
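For illustration, here is how that Jinja loop expands, sketched in plain Python instead of Jinja (the worker URL below is an assumption):

```python
# Sketch of the healthcheck.sh.j2 expansion, without pulling in Jinja itself.
healthcheck_urls = [
    "http://localhost:8080/health",  # main process
    "http://localhost:8081/health",  # hypothetical worker
]
script = "#!/bin/sh\n" + "".join(
    f"curl -fSs {url} || exit 1\n" for url in healthcheck_urls
)
print(script)
```

Any non-2xx response from a listed endpoint makes `curl -f` fail, so the script exits non-zero and Docker marks the container unhealthy.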
473473
474474 # Determine the load-balancing upstreams to configure
475475 nginx_upstream_config = ""
476
477 # At the same time, prepare a list of internal endpoints to healthcheck
478 # starting with the main process, which exists even if no workers do.
479 healthcheck_urls = ["http://localhost:8080/health"]
480
476481 for upstream_worker_type, upstream_worker_ports in nginx_upstreams.items():
477482 body = ""
478483 for port in upstream_worker_ports:
479484 body += " server localhost:%d;\n" % (port,)
485 healthcheck_urls.append("http://localhost:%d/health" % (port,))
480486
481487 # Add to the list of configured upstreams
482488 nginx_upstream_config += NGINX_UPSTREAM_CONFIG_BLOCK.format(
509515 worker_config=supervisord_config,
510516 )
511517
518 # healthcheck config
519 convert(
520 "/conf/healthcheck.sh.j2",
521 "/healthcheck.sh",
522 healthcheck_urls=healthcheck_urls,
523 )
524
512525 # Ensure the logging directory exists
513526 log_dir = data_dir + "/logs"
514527 if not os.path.exists(log_dir):
4343 - [Presence router callbacks](modules/presence_router_callbacks.md)
4444 - [Account validity callbacks](modules/account_validity_callbacks.md)
4545 - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
46 - [Background update controller callbacks](modules/background_update_controller_callbacks.md)
4647 - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
4748 - [Workers](workers.md)
4849 - [Using `synctl` with Workers](synctl_workers.md)
6364 - [Statistics](admin_api/statistics.md)
6465 - [Users](admin_api/user_admin_api.md)
6566 - [Server Version](admin_api/version_api.md)
67 - [Federation](usage/administration/admin_api/federation.md)
6668 - [Manhole](manhole.md)
6769 - [Monitoring](metrics-howto.md)
70 - [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md)
71 - [Useful SQL for Admins](usage/administration/useful_sql_for_admins.md)
72 - [Database Maintenance Tools](usage/administration/database_maintenance_tools.md)
73 - [State Groups](usage/administration/state_groups.md)
6874 - [Request log format](usage/administration/request_log.md)
75 - [Admin FAQ](usage/administration/admin_faq.md)
6976 - [Scripts]()
7077
7178 # Development
93100
94101 # Other
95102 - [Dependency Deprecation Policy](deprecation_policy.md)
103 - [Running Synapse on a Single-Board Computer](other/running_synapse_on_single_board_computers.md)
3737 The forward extremities of a room are used as the `prev_events` when the next event is sent.
3838
3939
40 ## Backwards extremity
40 ## Backward extremity
4141
4242 The current marker of where we have backfilled up to and will generally be the
43 oldest-in-time events we know of in the DAG.
43 `prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when
44 backfilling history.
4445
45 This is an event where we haven't fetched all of the `prev_events` for.
46
47 Once we have fetched all of its `prev_events`, it's unmarked as a backwards
48 extremity (although we may have formed new backwards extremities from the prev
49 events during the backfilling process).
46 When we persist a non-outlier event, we clear it as a backward extremity and set
47 all of its `prev_events` as the new backward extremities if they aren't already
48 persisted in the `events` table.
5049
5150
5251 ## Outliers
5554 room at that point in the DAG yet.
5655
5756 We won't *necessarily* have the `prev_events` of an `outlier` in the database,
58 but it's entirely possible that we *might*. The status of whether we have all of
59 the `prev_events` is marked as a [backwards extremity](#backwards-extremity).
57 but it's entirely possible that we *might*.
6058
6159 For example, when we fetch the event auth chain or state for a given event, we
6260 mark all of those claimed auth events as outliers because we haven't done the
11
22 *Synapse implementation-specific details for the media repository*
33
4 The media repository is where attachments and avatar photos are stored.
5 It stores attachment content and thumbnails for media uploaded by local users.
6 It caches attachment content and thumbnails for media uploaded by remote users.
4 The media repository
5 * stores avatars, attachments and their thumbnails for media uploaded by local
6 users.
7 * caches avatars, attachments and their thumbnails for media uploaded by remote
8 users.
9 * caches resources and thumbnails used for
10 [URL previews](development/url_previews.md).
711
8 ## Storage
12 All media in Matrix can be identified by a unique
13 [MXC URI](https://spec.matrix.org/latest/client-server-api/#matrix-content-mxc-uris),
14 consisting of a server name and media ID:
15 ```
16 mxc://<server-name>/<media-id>
17 ```
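As a sketch, a hypothetical helper (not part of Synapse) could split an MXC URI into those two components:

```python
from typing import Tuple

def parse_mxc(uri: str) -> Tuple[str, str]:
    """Split an MXC URI into (server_name, media_id); illustrative only."""
    prefix = "mxc://"
    if not uri.startswith(prefix):
        raise ValueError("not an MXC URI")
    server_name, _, media_id = uri[len(prefix):].partition("/")
    if not server_name or not media_id:
        raise ValueError("malformed MXC URI")
    return server_name, media_id

print(parse_mxc("mxc://matrix.org/aabbcccccccccccccccccccc"))
```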
918
10 Each item of media is assigned a `media_id` when it is uploaded.
11 The `media_id` is a randomly chosen, URL safe 24 character string.
19 ## Local Media
20 Synapse generates 24 character media IDs for content uploaded by local users.
21 These media IDs consist of upper and lowercase letters and are case-sensitive.
22 Other homeserver implementations may generate media IDs differently.
1223
13 Metadata such as the MIME type, upload time and length are stored in the
14 sqlite3 database indexed by `media_id`.
24 Local media is recorded in the `local_media_repository` table, which includes
25 metadata such as MIME types, upload times and file sizes.
26 Note that this table is shared by the URL cache, which has a different media ID
27 scheme.
1528
16 Content is stored on the filesystem under a `"local_content"` directory.
29 ### Paths
30 A file with media ID `aabbcccccccccccccccccccc` and its `128x96` `image/jpeg`
31 thumbnail, created by scaling, would be stored at:
32 ```
33 local_content/aa/bb/cccccccccccccccccccc
34 local_thumbnails/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
35 ```
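The scheme above can be reproduced with a small sketch (the real logic lives in Synapse's `MediaFilePaths` class; this is illustrative only):

```python
# Illustrative reconstruction of the local-media path scheme described above.
def local_content_path(media_id: str) -> str:
    # The first four characters of the media ID become two directory levels.
    return "/".join(["local_content", media_id[:2], media_id[2:4], media_id[4:]])

def local_thumbnail_path(media_id: str, width: int, height: int,
                         content_type: str, method: str) -> str:
    # MIME-type slashes become dashes in the thumbnail file name.
    type_part = content_type.replace("/", "-")
    return "/".join([
        "local_thumbnails", media_id[:2], media_id[2:4], media_id[4:],
        f"{width}-{height}-{type_part}-{method}",
    ])

print(local_content_path("aabbcccccccccccccccccccc"))
# local_content/aa/bb/cccccccccccccccccccc
print(local_thumbnail_path("aabbcccccccccccccccccccc", 128, 96, "image/jpeg", "scale"))
# local_thumbnails/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
```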
1736
18 Thumbnails are stored under a `"local_thumbnails"` directory.
37 ## Remote Media
38 When media from a remote homeserver is requested from Synapse, it is assigned
39 a local `filesystem_id`, with the same format as locally-generated media IDs,
40 as described above.
1941
20 The item with `media_id` `"aabbccccccccdddddddddddd"` is stored under
21 `"local_content/aa/bb/ccccccccdddddddddddd"`. Its thumbnail with width
22 `128` and height `96` and type `"image/jpeg"` is stored under
23 `"local_thumbnails/aa/bb/ccccccccdddddddddddd/128-96-image-jpeg"`
42 A record of remote media is stored in the `remote_media_cache` table, which
43 can be used to map remote MXC URIs (server names and media IDs) to local
44 `filesystem_id`s.
2445
25 Remote content is cached under `"remote_content"` directory. Each item of
26 remote content is assigned a local `"filesystem_id"` to ensure that the
27 directory structure `"remote_content/server_name/aa/bb/ccccccccdddddddddddd"`
28 is appropriate. Thumbnails for remote content are stored under
29 `"remote_thumbnail/server_name/..."`
46 ### Paths
47 A file from `matrix.org` with `filesystem_id` `aabbcccccccccccccccccccc` and its
48 `128x96` `image/jpeg` thumbnail, created by scaling, would be stored at:
49 ```
50 remote_content/matrix.org/aa/bb/cccccccccccccccccccc
51 remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
52 ```
53 Older thumbnails may omit the thumbnailing method:
54 ```
55 remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg
56 ```
57
58 Note that `remote_thumbnail/` does not have an `s`.
59
60 ## URL Previews
61 See [URL Previews](development/url_previews.md) for documentation on the URL preview
62 process.
63
64 When generating previews for URLs, Synapse may download and cache various
65 resources, including images. These resources are assigned temporary media IDs
66 of the form `yyyy-mm-dd_aaaaaaaaaaaaaaaa`, where `yyyy-mm-dd` is the current
67 date and `aaaaaaaaaaaaaaaa` is a random sequence of 16 case-sensitive letters.
68
69 The metadata for these cached resources is stored in the
70 `local_media_repository` and `local_media_repository_url_cache` tables.
71
72 Resources for URL previews are deleted after a few days.
73
74 ### Paths
75 The file with media ID `yyyy-mm-dd_aaaaaaaaaaaaaaaa` and its `128x96`
76 `image/jpeg` thumbnail, created by scaling, would be stored at:
77 ```
78 url_cache/yyyy-mm-dd/aaaaaaaaaaaaaaaa
79 url_cache_thumbnails/yyyy-mm-dd/aaaaaaaaaaaaaaaa/128-96-image-jpeg-scale
80 ```
0 # Background update controller callbacks
1
2 Background update controller callbacks allow module developers to control (e.g. rate-limit)
3 how database background updates are run. A database background update is an operation
4 Synapse runs on its database in the background after it starts. It's usually used to run
5 database operations that would take too long if they were run at the same time as schema
6 updates (which are run on startup) and would delay Synapse's startup too much: populating a
7 table with a large amount of data, adding an index on a big table, deleting superfluous
8 data, etc.
9
10 Background update controller callbacks can be registered using the module API's
11 `register_background_update_controller_callbacks` method. Only the first module (in order
12 of appearance in Synapse's configuration file) calling this method can register background
13 update controller callbacks; subsequent calls are ignored.
14
15 The available background update controller callbacks are:
16
17 ### `on_update`
18
19 _First introduced in Synapse v1.49.0_
20
21 ```python
22 def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int]
23 ```
24
25 Called when Synapse is about to run an iteration of a background update. The module is given
26 the name of the update, the name of the database, and a flag to indicate whether the background
27 update will happen in one go and may take a long time (e.g. creating indices). If this last
28 argument is set to `False`, the update will be run in batches.
29
30 The module must return an async context manager. It will be entered before Synapse runs a
31 background update; this should return the desired duration of the iteration, in
32 milliseconds.
33
34 The context manager will be exited when the iteration completes. Note that the duration
35 returned by the context manager is a target, and an iteration may take substantially more
36 or less time. If the `one_shot` flag is set to `True`, the duration returned is ignored.
37
38 __Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is
39 because asynchronous operations are expected to be run by the async context manager.
40
41 This callback is required when registering any other background update controller callback.
42
43 ### `default_batch_size`
44
45 _First introduced in Synapse v1.49.0_
46
47 ```python
48 async def default_batch_size(update_name: str, database_name: str) -> int
49 ```
50
51 Called before the first iteration of a background update, with the name of the update and
52 of the database. The module must return the number of elements to process in this first
53 iteration.
54
55 If this callback is not defined, Synapse will use a default value of 100.
56
57 ### `min_batch_size`
58
59 _First introduced in Synapse v1.49.0_
60
61 ```python
62 async def min_batch_size(update_name: str, database_name: str) -> int
63 ```
64
65 Called before running a new batch for a background update, with the name of the update and
66 of the database. The module must return an integer representing the minimum number of
67 elements to process in this iteration. This number must be at least 1, and is used to
68 ensure that progress is always made.
69
70 If this callback is not defined, Synapse will use a default value of 100.
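For illustration, the batch-size callbacks are plain async functions; the values below are arbitrary assumptions, and they would be registered alongside the required `on_update` callback:

```python
import asyncio

async def default_batch_size(update_name: str, database_name: str) -> int:
    # Start with a conservative first batch (arbitrary choice for this sketch).
    return 50

async def min_batch_size(update_name: str, database_name: str) -> int:
    # Always process at least one element, so progress is guaranteed.
    return 1

print(asyncio.run(default_batch_size("populate_stats", "main")))  # 50
```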
7070 ## Registering a callback
7171
7272 Modules can use Synapse's module API to register callbacks. Callbacks are functions that
73 Synapse will call when performing specific actions. Callbacks must be asynchronous, and
74 are split in categories. A single module may implement callbacks from multiple categories,
75 and is under no obligation to implement all callbacks from the categories it registers
76 callbacks for.
73 Synapse will call when performing specific actions. Callbacks must be asynchronous (unless
74 specified otherwise), and are split in categories. A single module may implement callbacks
75 from multiple categories, and is under no obligation to implement all callbacks from the
76 categories it registers callbacks for.
7777
7878 Modules can register callbacks using one of the module API's `register_[...]_callbacks`
7979 methods. The callback functions are passed to these methods as keyword arguments, with
80 the callback name as the argument name and the function as its value. This is demonstrated
81 in the example below. A `register_[...]_callbacks` method exists for each category.
80 the callback name as the argument name and the function as its value. A
81 `register_[...]_callbacks` method exists for each category.
8282
8383 Callbacks for each category can be found on their respective page of the
8484 [Synapse documentation website](https://matrix-org.github.io/synapse).
8282
8383 ### Dex
8484
85 [Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider.
85 [Dex][dex-idp] is a simple, open-source OpenID Connect Provider.
8686 Although it is designed to help building a full-blown provider with an
8787 external database, it can be configured with static passwords in a config file.
8888
522522 email_template: "{{ user.email }}"
523523 ```
524524
525 ## Django OAuth Toolkit
525 ### Django OAuth Toolkit
526526
527527 [django-oauth-toolkit](https://github.com/jazzband/django-oauth-toolkit) is a
528528 Django application providing out of the box all the endpoints, data and logic
0 ## Summary of performance impact of running on resource constrained devices such as SBCs
1
2 I've been running my homeserver on a cubietruck at home now for some time and am often replying to statements like "you need loads of ram to join large rooms" with "it works fine for me". I thought it might be useful to curate a summary of the issues you're likely to run into, to serve as a scaling-down guide, highlight them for development work, or end up as documentation. It seems that once you get up to about 4x1.5GHz arm64 4GiB these issues are no longer a problem.
3
4 - **Platform**: 2x1GHz armhf 2GiB ram [Single-board computers](https://wiki.debian.org/CheapServerBoxHardware), SSD, postgres.
5
6 ### Presence
7
8 This is the main reason people have a poor Matrix experience on resource constrained homeservers. Element Web will frequently report that the server is offline while the Python process is pegged at 100% CPU. This feature is used to tell when other users are active (have a client app in the foreground) and therefore more likely to respond, but requires a lot of network activity to maintain even when nobody is talking in a room.
9
10 ![Screenshot_2020-10-01_19-29-46](https://user-images.githubusercontent.com/71895/94848963-a47a3580-041c-11eb-8b6e-acb772b4259e.png)
11
12 While Synapse does have some performance issues with presence [#3971](https://github.com/matrix-org/synapse/issues/3971), the fundamental problem is that this is an easy feature to implement for a centralised service at nearly no overhead, but federation makes it combinatorial [#8055](https://github.com/matrix-org/synapse/issues/8055). There is also a client-side config option which disables the UI and idle tracking [enable_presence_by_hs_url] to blacklist the largest instances, but I didn't notice much difference, so I recommend disabling the feature entirely at the server level as well.
13
14 [enable_presence_by_hs_url]: https://github.com/vector-im/element-web/blob/v1.7.8/config.sample.json#L45
15
16 ### Joining
17
18 Joining a "large", federated room will initially fail with the below message in Element Web, but waiting a while (10-60mins) and trying again will succeed without any issue. What counts as "large" is not message history, user count, connections to homeservers or even a simple count of the state events; it is instead how long the state resolution algorithm takes. However, each of those numbers is a reasonable proxy, so we can use them as estimates, since user count is one of the few things you see before joining.
19
20 ![Screenshot_2020-10-02_17-15-06](https://user-images.githubusercontent.com/71895/94945781-18771500-04d3-11eb-8419-83c2da73a341.png)
21
22 This is [#1211](https://github.com/matrix-org/synapse/issues/1211) and will also hopefully be mitigated by peeking [matrix-org/matrix-doc#2753](https://github.com/matrix-org/matrix-doc/pull/2753), so at least you don't need to wait for a join to complete before finding out if it's the kind of room you want. Note that you should first disable presence, otherwise it'll just make the situation worse [#3120](https://github.com/matrix-org/synapse/issues/3120). There is a lot of database interaction too, so make sure you've [migrated your data](../postgres.md) from the default sqlite to postgresql. Personally, I recommend patience - once the initial join is complete there are rarely any issues with actually interacting with the room, but if you like you can just block "large" rooms entirely.
23
24 ### Sessions
25
26 Anything that requires modifying the device list [#7721](https://github.com/matrix-org/synapse/issues/7721) will take a while to propagate, again taking the client "Offline" until it's complete. This includes signing in and out, editing the public name and verifying e2ee. The main mitigation I recommend is to keep long-running sessions open e.g. by using Firefox SSB "Use this site in App mode" or Chromium PWA "Install Element".
27
28 ### Recommended configuration
29
30 Put the below in a new file at `/etc/matrix-synapse/conf.d/sbc.yaml` to override the defaults in `homeserver.yaml`.
31
32 ```yaml
33 # Set to false to disable presence tracking on this homeserver.
34 use_presence: false
35
36 # When this is enabled, the room "complexity" will be checked before a user
37 # joins a new remote room. If it is above the complexity limit, the server will
38 # disallow joining, or will instantly leave.
39 limit_remote_rooms:
40   # Uncomment to enable room complexity checking.
41   #enabled: true
42   complexity: 3.0
43
44 # Database configuration
45 database:
46   name: psycopg2
47   args:
48     user: matrix-synapse
49     # Generate a long, secure one with a password manager
50     password: hunter2
51     database: matrix-synapse
52     host: localhost
53     cp_min: 5
54     cp_max: 10
55 ```
56
57 Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986). You can find join times and your most complex rooms like this:
58
59 ```
60 admin@homeserver:~$ zgrep '/client/r0/join/' /var/log/matrix-synapse/homeserver.log* | awk '{print $18, $25}' | sort --human-numeric-sort
61 29.922sec/-0.002sec /_matrix/client/r0/join/%23debian-fasttrack%3Apoddery.com
62 182.088sec/0.003sec /_matrix/client/r0/join/%23decentralizedweb-general%3Amatrix.org
63 911.625sec/-570.847sec /_matrix/client/r0/join/%23synapse%3Amatrix.org
64
65 admin@homeserver:~$ sudo --user postgres psql matrix-synapse --command 'select canonical_alias, joined_members, current_state_events from room_stats_state natural join room_stats_current where canonical_alias is not null order by current_state_events desc fetch first 5 rows only'
66             canonical_alias           | joined_members | current_state_events
67 --------------------------------------+----------------+----------------------
68  #_oftc_#debian:matrix.org            |            871 |                52355
69  #matrix:matrix.org                   |           6379 |                10684
70  #irc:matrix.org                      |            461 |                 3751
71  #decentralizedweb-general:matrix.org |            997 |                 1509
72  #whatsapp:maunium.net                |            554 |                  854
73 ```
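As a sketch of the check that `limit_remote_rooms` enables, using the documented formula above (illustrative only, not Synapse's actual code):

```python
# Illustrative sketch of the room-complexity check described above, based on
# the documented formula current_state_events / 500. Not Synapse's code.
COMPLEXITY_DIVISOR = 500

def room_complexity(current_state_events: int) -> float:
    return current_state_events / COMPLEXITY_DIVISOR

def may_join(current_state_events: int, limit: float = 3.0) -> bool:
    # With limit_remote_rooms.complexity: 3.0, rooms with more than
    # 1500 current state events would be rejected.
    return room_complexity(current_state_events) <= limit

# #_oftc_#debian:matrix.org from the query above has 52355 state events:
print(may_join(52355))  # → False
```

With `complexity: 3.0`, only the last two rooms in the query output above would be joinable.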
117117 Note that the appropriate values for those fields depend on the amount
118118 of free memory the database host has available.
119119
120 Additionally, admins of large deployments might want to consider using huge pages
121 to help manage memory, especially when using large values of `shared_buffers`. You
122 can read more about that [here](https://www.postgresql.org/docs/10/kernel-resources.html#LINUX-HUGE-PAGES).
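As a sketch of what that involves (values are illustrative; see the linked PostgreSQL documentation for how to size the reservation):

```
# postgresql.conf: use huge pages when the kernel can provide them, and
# fall back to normal pages otherwise ("on" would refuse to start instead)
huge_pages = try

# /etc/sysctl.conf: reserve enough 2 MB huge pages to cover shared_buffers,
# e.g. roughly 2048 pages for shared_buffers = 4GB
vm.nr_hugepages = 2048
```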
120123
121124 ## Porting from SQLite
122125
12081208 #
12091209 #session_lifetime: 24h
12101210
1211 # Time that an access token remains valid for, if the session is
1212 # using refresh tokens.
1213 # For more information about refresh tokens, please see the manual.
1214 # Note that this only applies to clients which advertise support for
1215 # refresh tokens.
1216 #
1217 # Note also that this is calculated at login time and refresh time:
1218 # changes are not applied to existing sessions until they are refreshed.
1219 #
1220 # By default, this is 5 minutes.
1221 #
1222 #refreshable_access_token_lifetime: 5m
1223
1224 # Time that a refresh token remains valid for (provided that it is not
1225 # exchanged for another one first).
1226 # This option can be used to automatically log-out inactive sessions.
1227 # Please see the manual for more information.
1228 #
1229 # Note also that this is calculated at login time and refresh time:
1230 # changes are not applied to existing sessions until they are refreshed.
1231 #
1232 # By default, this is infinite.
1233 #
1234 #refresh_token_lifetime: 24h
1235
1236 # Time that an access token remains valid for, if the session is NOT
1237 # using refresh tokens.
1238 # Please note that not all clients support refresh tokens, so setting
1239 # this to a short value may be inconvenient for some users who will
1240 # then be logged out frequently.
1241 #
1242 # Note also that this is calculated at login time: changes are not applied
1243 # retrospectively to existing sessions for users that have already logged in.
1244 #
1245 # By default, this is infinite.
1246 #
1247 #nonrefreshable_access_token_lifetime: 24h
1248
12111249 # The user must provide all of the below types of 3PID when registering.
12121250 #
12131251 #registrations_require_3pid:
7070 * `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's
7171 sender
7272 * `sender_hash`: a hash of the user ID of the sender
73 * `msgtype`: the type of the message
74 * `body_text_html`: HTML representation of the message
75 * `body_text_plain`: plaintext representation of the message
76 * `image_url`: `mxc://` URL of an image, when `msgtype` is `m.image`
7377 * `link`: a `matrix.to` link to the room
78 * `avator_url`: URL of the room's avatar
7479 * `reason`: information on the event that triggered the email to be sent. It's an
7580 object with the following attributes:
7681 * `room_id`: the ID of the room the event was sent in
0 # Federation API
1
2 This API allows a server administrator to manage Synapse's federation with other homeservers.
3
4 Note: This API is new, experimental and "subject to change".
5
6 ## List of destinations
7
8 This API gets the current destination retry timing info for all remote servers.
9
10 The list contains all the servers with which the server federates,
11 regardless of whether an error occurred or not.
12 If an error occurs, it may take up to 20 minutes for the error to be displayed here,
13 as a complete retry must have failed.
14
15 The API is:
16
17 A standard request with no filtering:
18
19 ```
20 GET /_synapse/admin/v1/federation/destinations
21 ```
22
23 A response body like the following is returned:
24
25 ```json
26 {
27 "destinations":[
28 {
29 "destination": "matrix.org",
30 "retry_last_ts": 1557332397936,
31 "retry_interval": 3000000,
32 "failure_ts": 1557329397936,
33 "last_successful_stream_ordering": null
34 }
35 ],
36 "total": 1
37 }
38 ```
39
40 To paginate, check for `next_token` and, if present, call the endpoint again
41 with `from` set to the value of `next_token`. This will return a new page.
42
43 If the endpoint does not return a `next_token` then there are no more destinations
44 to paginate through.
45
46 **Parameters**
47
48 The following query parameters are available:
49
50 - `from` - Offset in the returned list. Defaults to `0`.
51 - `limit` - Maximum amount of destinations to return. Defaults to `100`.
52 - `order_by` - The method in which to sort the returned list of destinations.
53 Valid values are:
54 - `destination` - Destinations are ordered alphabetically by remote server name.
55 This is the default.
56 - `retry_last_ts` - Destinations are ordered by time of last retry attempt in ms.
57 - `retry_interval` - Destinations are ordered by how long until next retry in ms.
58 - `failure_ts` - Destinations are ordered by when the server started failing in ms.
59 - `last_successful_stream_ordering` - Destinations are ordered by the stream ordering
60 of the most recent successfully-sent PDU.
61 - `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting
62 this value to `b` will reverse the above sort order. Defaults to `f`.
63
64 *Caution:* The database only has an index on the column `destination`.
65 Sorting by any other column can therefore cause a large load on the
66 database, especially in large environments.
67
68 **Response**
69
70 The following fields are returned in the JSON response body:
71
72 - `destinations` - An array of objects, each containing information about a destination.
73 Destination objects contain the following fields:
74 - `destination` - string - Name of the remote server to federate.
75 - `retry_last_ts` - integer - The last time Synapse tried and failed to reach the
76 remote server, in ms. This is `0` if the last attempt to communicate with the
77 remote server was successful.
78 - `retry_interval` - integer - How long, in ms, Synapse will wait after the last
79 failed attempt before trying to reach the remote server again. This is `0` if no further retrying is occurring.
80 - `failure_ts` - nullable integer - The first time Synapse tried and failed to reach the
81 remote server, in ms. This is `null` if communication with the remote server has never failed.
82 - `last_successful_stream_ordering` - nullable integer - The stream ordering of the most
83 recent successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation)
84 to this destination, or `null` if this information has not been tracked yet.
85 - `next_token` - string representing a positive integer - Indication for pagination. See above.
86 - `total` - integer - Total number of destinations.
87
88 ## Destination Details API
89
90 This API gets the retry timing info for a specific remote server.
91
92 The API is:
93
94 ```
95 GET /_synapse/admin/v1/federation/destinations/<destination>
96 ```
97
98 A response body like the following is returned:
99
100 ```json
101 {
102 "destination": "matrix.org",
103 "retry_last_ts": 1557332397936,
104 "retry_interval": 3000000,
105 "failure_ts": 1557329397936,
106 "last_successful_stream_ordering": null
107 }
108 ```
109
110 **Response**
111
112 The response fields are the same as those in the `destinations` array in the
113 [List of destinations](#list-of-destinations) response.
0 ## Admin FAQ
1
2 How do I become a server admin?
3 ---
4 If your server already has an admin account you should use the user admin API to promote other accounts to become admins. See [User Admin API](../../admin_api/user_admin_api.md#Change-whether-a-user-is-a-server-administrator-or-not)
5
6 If you don't have any admin accounts yet, you won't be able to use the admin API, so you'll have to edit the database manually. Manually editing the database is generally not recommended, so once you have an admin account, use the admin APIs to make further changes.
7
8 ```sql
9 UPDATE users SET admin = 1 WHERE name = '@foo:bar.com';
10 ```
11 What servers are my server talking to?
12 ---
13 Run this SQL query on your database:
14 ```sql
15 SELECT * FROM destinations;
16 ```
17
18 What servers are currently participating in this room?
19 ---
20 Run this SQL query on your database:
21 ```sql
22 SELECT DISTINCT split_part(state_key, ':', 2)
23 FROM current_state_events AS c
24 INNER JOIN room_memberships AS m USING (room_id, event_id)
25 WHERE room_id = '!cURbafjkfsMDVwdRDQ:matrix.org' AND membership = 'join';
26 ```
27
28 What users are registered on my server?
29 ---
30 ```sql
31 SELECT name FROM users;
32 ```
33
34 Manually resetting passwords:
35 ---
36 See https://github.com/matrix-org/synapse/blob/master/README.rst#password-reset
37
38 I have a problem with my server. Can I just delete my database and start again?
39 ---
40 Deleting your database is unlikely to make anything better.
41
42 It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated network: lots of other servers have information about your server.
43
44 For example: other servers might think that you are in a room, your server will think that you are not, and you'll probably be unable to interact with that room in a sensible way ever again.
45
46 In general, there are better solutions to any problem than dropping the database. Come and seek help in https://matrix.to/#/#synapse:matrix.org.
47
48 There are two exceptions when it might be sensible to delete your database and start again:
49 * You have *never* joined any rooms which are federated with other servers. For instance, a local deployment which the outside world can't talk to.
50 * You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database.
51 (In both cases you probably also want to clear out the media_store.)
52
53 I've stuffed up access to my room, how can I delete it to free up the alias?
54 ---
55 Using the following curl command:
56 ```
57 curl -H 'Authorization: Bearer <access-token>' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/<room-alias>
58 ```
59 `<access-token>` - can be obtained in Riot/Element by looking in the settings; at the bottom is
60 Access Token: \<click to reveal\>
61
62 `<room-alias>` - the room alias, e.g. `#my_room:matrix.org`. This possibly needs to be URL-encoded as well, for example `%23my_room%3Amatrix.org`.
63
64 How can I find the lines corresponding to a given HTTP request in my homeserver log?
65 ---
66
67 Synapse tags each log line according to the HTTP request it is processing. When it finishes processing each request, it logs a line containing the words `Processed request: `. For example:
68
69 ```
70 2019-02-14 22:35:08,196 - synapse.access.http.8008 - 302 - INFO - GET-37 - ::1 - 8008 - {@richvdh:localhost} Processed request: 0.173sec/0.001sec (0.002sec, 0.000sec) (0.027sec/0.026sec/2) 687B 200 "GET /_matrix/client/r0/sync HTTP/1.1" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Safari/537.36" [0 dbevts]"
71 ```
72
73 Here we can see that the request has been tagged with `GET-37`. (The tag depends on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, `OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do:
74
75 ```
76 grep 'GET-37' homeserver.log
77 ```
78
79 If you want to paste that output into a github issue or matrix room, please remember to surround it with triple-backticks (```) to make it legible (see https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code).
80
81
82 What do all those fields in the 'Processed' line mean?
83 ---
84 See [Request log format](request_log.md).
85
86
87 What are the biggest rooms on my server?
88 ---
89
90 ```sql
91 SELECT s.canonical_alias, g.room_id, count(*) AS num_rows
92 FROM
93 state_groups_state AS g,
94 room_stats_state AS s
95 WHERE g.room_id = s.room_id
96 GROUP BY s.canonical_alias, g.room_id
97 ORDER BY num_rows desc
98 LIMIT 10;
99 ```
100
101 You can also use the [List Room API](../../admin_api/rooms.md#list-room-api)
102 and `order_by` `state_events`.
0 This blog post by Victor Berger explains how to use many of the tools listed on this page: https://levans.fr/shrink-synapse-database.html
1
2 # List of useful tools and scripts for maintaining the Synapse database:
3
4 ## [Purge Remote Media API](../../admin_api/media_admin_api.md#purge-remote-media-api)
5 The purge remote media API allows server admins to purge old cached remote media.
6
7 ## [Purge Local Media API](../../admin_api/media_admin_api.md#delete-local-media)
8 This API deletes the *local* media from the disk of your own server.
9
10 ## [Purge History API](../../admin_api/purge_history_api.md)
11 The purge history API allows server admins to purge historic events from their database, reclaiming disk space.
12
13 ## [synapse-compress-state](https://github.com/matrix-org/rust-synapse-compress-state)
14 Tool for compressing (deduplicating) the `state_groups_state` table.
15
16 ## [SQL for analyzing Synapse PostgreSQL database stats](useful_sql_for_admins.md)
17 Some easy SQL that reports useful stats about your Synapse database.
0 # How do State Groups work?
1
2 As a general rule, I encourage people who want to understand the deepest darkest secrets of the database schema to drop by #synapse-dev:matrix.org and ask questions.
3
4 However, one question that comes up frequently is that of how "state groups" work, and why the `state_groups_state` table gets so big, so here's an attempt to answer that question.
5
6 We need to be able to relatively quickly calculate the state of a room at any point in that room's history. In other words, we need to know the state of the room at each event in that room. This is done as follows:
7
8 A sequence of events where the state is the same are grouped together into a `state_group`; the mapping is recorded in `event_to_state_groups`. (Technically speaking, since a state event usually changes the state in the room, we are recording the state of the room *after* the given event id: which is to say, to a handwavey simplification, the first event in a state group is normally a state event, and others in the same state group are normally non-state-events.)
9
10 `state_groups` records, for each state group, the id of the room that we're looking at, and also the id of the first event in that group. (I'm not sure if that event id is used much in practice.)
11
12 Now, if we stored all the room state for each `state_group`, that would be a huge amount of data. Instead, for each state group, we normally store the difference between the state in that group and some other state group, and only occasionally (every 100 state changes or so) record the full state.
13
14 So, most state groups have an entry in `state_group_edges` (don't ask me why it's not a column in `state_groups`) which records the previous state group in the room, and `state_groups_state` records the differences in state since that previous state group.
15
16 A full state group just records the event id for each piece of state in the room at that point.
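The lookup described above can be sketched like this, with the table contents modelled as plain dicts (illustrative only, not Synapse's actual implementation):

```python
# Illustrative sketch of resolving the full state for a state group: follow
# state_group_edges back to the oldest (full) group, then apply each group's
# state_groups_state deltas in order. Not Synapse's actual implementation.
def resolve_state(group, edges, deltas):
    """edges: {group: previous_group}; deltas: {group: {(type, state_key): event_id}}."""
    chain = []
    current = group
    while current is not None:
        chain.append(current)
        current = edges.get(current)
    state = {}
    for g in reversed(chain):  # oldest first, so newer deltas overwrite
        state.update(deltas.get(g, {}))
    return state
```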
17
18 ## Known bugs with state groups
19
20 There are various reasons that we can end up creating many more state groups than we need: see https://github.com/matrix-org/synapse/issues/3364 for more details.
21
22 ## Compression tool
23
24 There is a tool at https://github.com/matrix-org/rust-synapse-compress-state which can compress the `state_groups_state` table on a room-by-room basis (essentially, it reduces the number of "full" state groups). This can result in dramatic reductions in the storage used.
0 ## Understanding Synapse through Grafana graphs
1
2 It is possible to monitor much of the internal state of Synapse using [Prometheus](https://prometheus.io)
3 metrics and [Grafana](https://grafana.com/).
4 A guide for configuring Synapse to provide metrics is available [here](../../metrics-howto.md)
5 and information on setting up Grafana is [here](https://github.com/matrix-org/synapse/tree/master/contrib/grafana).
6 In this setup, Prometheus will periodically scrape the information Synapse provides and
7 store a record of it over time. Grafana is then used as an interface to query and
8 present this information through a series of pretty graphs.
9
10 Once you have Grafana set up, and assuming you're using [our Grafana dashboard template](https://github.com/matrix-org/synapse/blob/master/contrib/grafana/synapse.json), look for the following graphs when debugging a slow or overloaded Synapse:
11
12 ## Message Event Send Time
13
14 ![image](https://user-images.githubusercontent.com/1342360/82239409-a1c8e900-9930-11ea-8081-e4614e0c63f4.png)
15
16 This, along with the CPU and Memory graphs, is a good way to check the general health of your Synapse instance. It represents how long it takes for a user on your homeserver to send a message.
17
18 ## Transaction Count and Transaction Duration
19
20 ![image](https://user-images.githubusercontent.com/1342360/82239985-8d392080-9931-11ea-80d0-843ab2f22e1e.png)
21
22 ![image](https://user-images.githubusercontent.com/1342360/82240050-ab068580-9931-11ea-98f1-f94671cbac9a.png)
23
24 These graphs show the database transactions that are occurring the most frequently, as well as those that are taking the most time to execute.
25
26 ![image](https://user-images.githubusercontent.com/1342360/82240192-e86b1300-9931-11ea-9aac-3e2c9bfa6fdc.png)
27
28 In the first graph, we can see obvious spikes corresponding to lots of `get_user_by_id` transactions. This would be useful information to figure out which part of the Synapse codebase is potentially creating a heavy load on the system. However, be sure to cross-reference this with Transaction Duration, which shows that `get_user_by_id` is actually a very quick database transaction and isn't causing as much load as others, like `persist_events`:
29
30 ![image](https://user-images.githubusercontent.com/1342360/82240467-62030100-9932-11ea-8db9-917f2d977fe1.png)
31
32 Still, it's probably worth investigating why we're getting users from the database that often, and whether it's possible to reduce the amount of queries we make by adjusting our cache factor(s).
33
34 The `persist_events` transaction is responsible for saving new room events to the Synapse database, so can often show a high transaction duration.
35
36 ## Federation
37
38 The charts in the "Federation" section show information about incoming and outgoing federation requests. Federation data can be divided into two basic types:
39
40 - PDU (Persistent Data Unit) - room events: messages, state events (join/leave), etc. These are permanently stored in the database.
41 - EDU (Ephemeral Data Unit) - other data, which need not be stored permanently, such as read receipts and typing notifications.
42
43 The "Outgoing EDUs by type" chart shows the EDUs within outgoing federation requests by type: `m.device_list_update`, `m.direct_to_device`, `m.presence`, `m.receipt`, `m.typing`.
44
45 If you see a large number of `m.presence` EDUs and are having trouble with too much CPU load, you can disable `presence` in the Synapse config. See also [#3971](https://github.com/matrix-org/synapse/issues/3971).
46
47 ## Caches
48
49 ![image](https://user-images.githubusercontent.com/1342360/82240572-8b239180-9932-11ea-96ff-6b5f0e57ebe5.png)
50
51 ![image](https://user-images.githubusercontent.com/1342360/82240666-b8703f80-9932-11ea-86af-9f663988d8da.png)
52
53 This is quite a useful graph. It shows how many times Synapse attempts to retrieve a piece of data from a cache that does not contain it, resulting in a call to the database. We can see here that the `_get_joined_profile_from_event_id` cache is being requested a lot, and often the data we're after is not cached.
54
55 Cross-referencing this with the Eviction Rate graph, which shows that entries are being evicted from `_get_joined_profile_from_event_id` quite often:
56
57 ![image](https://user-images.githubusercontent.com/1342360/82240766-de95df80-9932-11ea-8c15-5acfc57c48da.png)
58
59 we should probably consider increasing the size of that cache by raising its cache factor (a multiplier for the size of an individual cache). Information on doing so is available [here](https://github.com/matrix-org/synapse/blob/ee421e524478c1ad8d43741c27379499c2f6135c/docs/sample_config.yaml#L608-L642) (note that configuring individual cache factors through the configuration file is available in Synapse v1.14.0+, whereas doing so through environment variables has been supported for a very long time). Note that this will increase Synapse's overall memory usage.
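As a sketch, the per-cache factor could be raised in `homeserver.yaml`. Note the assumptions here: `caches.per_cache_factors` is the option name from Synapse's sample config (available from v1.14.0), and the cache name is taken from the graph above, so check both against your version's sample config:

```yaml
caches:
  # Multiplier applied to all cache sizes unless overridden below.
  global_factor: 1.0
  per_cache_factors:
    # Double the size of the cache seen churning in the eviction graph.
    _get_joined_profile_from_event_id: 2.0
```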
60
61 ## Forward Extremities
62
63 ![image](https://user-images.githubusercontent.com/1342360/82241440-13566680-9934-11ea-8b88-ba468db937ed.png)
64
65 Forward extremities are the leaf events at the end of a room's DAG, i.e. events that have no children. The more of them that exist in a room, the more [state resolution](https://spec.matrix.org/v1.1/server-server-api/#room-state-resolution) Synapse needs to perform (hint: it's an expensive operation). While Synapse has code to prevent too many of these existing at one time in a room, bugs can sometimes make them crop up again.
66
67 If a room has >10 forward extremities, it's worth checking which room is the culprit and potentially removing them using the SQL queries mentioned in [#1760](https://github.com/matrix-org/synapse/issues/1760).
68
69 ## Garbage Collection
70
71 ![image](https://user-images.githubusercontent.com/1342360/82241911-da6ac180-9934-11ea-9a0d-a311fe22acd0.png)
72
73 Large spikes in garbage collection times (bigger than shown here, in the
74 multiple-seconds range) can cause lots of problems for Synapse performance. Garbage
75 collection time is more an indicator and symptom of other problems, though, so check the other graphs for what might be causing it.
76
77 ## Final Thoughts
78
79 If you're still having performance problems with your Synapse instance and you've
80 tried everything you can, it may just be a lack of system resources. Consider adding
81 more CPU and RAM, and make use of [worker mode](../../workers.md)
82 to spread the load across multiple CPU cores or multiple machines.
83
0 # Some useful SQL queries for Synapse Admins
1
2 ## Size of full matrix db
3 `SELECT pg_size_pretty( pg_database_size( 'matrix' ) );`
4 ### Result example:
5 ```
6 pg_size_pretty
7 ----------------
8 6420 MB
9 (1 row)
10 ```
11 ## Show top 20 largest rooms by state event count
12 ```sql
13 SELECT r.name, s.room_id, s.current_state_events
14 FROM room_stats_current s
15 LEFT JOIN room_stats_state r USING (room_id)
16 ORDER BY current_state_events DESC
17 LIMIT 20;
18 ```
19
20 and by `state_groups_state` row count:
21 ```sql
22 SELECT rss.name, s.room_id, count(s.room_id) FROM state_groups_state s
23 LEFT JOIN room_stats_state rss USING (room_id)
24 GROUP BY s.room_id, rss.name
25 ORDER BY count(s.room_id) DESC
26 LIMIT 20;
27 ```
28 The same, but with the join removed for performance reasons:
29 ```sql
30 SELECT s.room_id, count(s.room_id) FROM state_groups_state s
31 GROUP BY s.room_id
32 ORDER BY count(s.room_id) DESC
33 LIMIT 20;
34 ```
35
36 ## Show top 20 largest tables by row count
37 ```sql
38 SELECT relname, n_live_tup as rows
39 FROM pg_stat_user_tables
40 ORDER BY n_live_tup DESC
41 LIMIT 20;
42 ```
43 This query is quick, but may be very approximate; for the exact number of rows, use `SELECT COUNT(*) FROM <table_name>`.
44 ### Result example:
45 ```
46 state_groups_state - 161687170
47 event_auth - 8584785
48 event_edges - 6995633
49 event_json - 6585916
50 event_reference_hashes - 6580990
51 events - 6578879
52 received_transactions - 5713989
53 event_to_state_groups - 4873377
54 stream_ordering_to_exterm - 4136285
55 current_state_delta_stream - 3770972
56 event_search - 3670521
57 state_events - 2845082
58 room_memberships - 2785854
59 cache_invalidation_stream - 2448218
60 state_groups - 1255467
61 state_group_edges - 1229849
62 current_state_events - 1222905
63 users_in_public_rooms - 364059
64 device_lists_stream - 326903
65 user_directory_search - 316433
66 ```
67
68 ## Show top 20 rooms by new event count in the last day:
69 ```sql
70 SELECT e.room_id, r.name, COUNT(e.event_id) cnt FROM events e
71 LEFT JOIN room_stats_state r USING (room_id)
72 WHERE e.origin_server_ts >= DATE_PART('epoch', NOW() - INTERVAL '1 day') * 1000 GROUP BY e.room_id, r.name ORDER BY cnt DESC LIMIT 20;
73 ```
74
75 ## Show top 20 users on the homeserver by events sent (messages) in the last month:
76 ```sql
77 SELECT user_id, SUM(total_events)
78 FROM user_stats_historical
79 WHERE TO_TIMESTAMP(end_ts/1000) AT TIME ZONE 'UTC' > date_trunc('day', now() - interval '1 month')
80 GROUP BY user_id
81 ORDER BY SUM(total_events) DESC
82 LIMIT 20;
83 ```
84
85 ## Show the last 100 messages from a given user, with room names:
86 ```sql
87 SELECT e.room_id, r.name, e.event_id, e.type, e.content, j.json FROM events e
88 LEFT JOIN event_json j USING (event_id)
89 LEFT JOIN room_stats_state r USING (room_id)
90 WHERE sender = '@LOGIN:example.com'
91 AND e.type = 'm.room.message'
92 ORDER BY stream_ordering DESC
93 LIMIT 100;
94 ```
95
96 ## Show top 20 largest tables by storage size
97 ```sql
98 SELECT nspname || '.' || relname AS "relation",
99 pg_size_pretty(pg_total_relation_size(C.oid)) AS "total_size"
100 FROM pg_class C
101 LEFT JOIN pg_namespace N ON (N.oid = C.relnamespace)
102 WHERE nspname NOT IN ('pg_catalog', 'information_schema')
103 AND C.relkind <> 'i'
104 AND nspname !~ '^pg_toast'
105 ORDER BY pg_total_relation_size(C.oid) DESC
106 LIMIT 20;
107 ```
108 ### Result example:
109 ```
110 public.state_groups_state - 27 GB
111 public.event_json - 9855 MB
112 public.events - 3675 MB
113 public.event_edges - 3404 MB
114 public.received_transactions - 2745 MB
115 public.event_reference_hashes - 1864 MB
116 public.event_auth - 1775 MB
117 public.stream_ordering_to_exterm - 1663 MB
118 public.event_search - 1370 MB
119 public.room_memberships - 1050 MB
120 public.event_to_state_groups - 948 MB
121 public.current_state_delta_stream - 711 MB
122 public.state_events - 611 MB
123 public.presence_stream - 530 MB
124 public.current_state_events - 525 MB
125 public.cache_invalidation_stream - 466 MB
126 public.receipts_linearized - 279 MB
127 public.state_groups - 160 MB
128 public.device_lists_remote_cache - 124 MB
129 public.state_group_edges - 122 MB
130 ```
131
132 ## Show rooms with names, sorted by event count
133 `echo "select event_json.room_id,room_stats_state.name from event_json,room_stats_state where room_stats_state.room_id=event_json.room_id" | psql synapse | sort | uniq -c | sort -n`
134 ### Result example:
135 ```
136 9459 !FPUfgzXYWTKgIrwKxW:matrix.org | This Week in Matrix
137 9459 !FPUfgzXYWTKgIrwKxW:matrix.org | This Week in Matrix (TWIM)
138 17799 !iDIOImbmXxwNngznsa:matrix.org | Linux in Russian
139 18739 !GnEEPYXUhoaHbkFBNX:matrix.org | Riot Android
140 23373 !QtykxKocfZaZOUrTwp:matrix.org | Matrix HQ
141 39504 !gTQfWzbYncrtNrvEkB:matrix.org | ru.[matrix]
142 43601 !iNmaIQExDMeqdITdHH:matrix.org | Riot
143 43601 !iNmaIQExDMeqdITdHH:matrix.org | Riot Web/Desktop
144 ```
145
146 ## Look up room state info by a list of room_ids
147 ```sql
148 SELECT rss.room_id, rss.name, rss.canonical_alias, rss.topic, rss.encryption, rsc.joined_members, rsc.local_users_in_room, rss.join_rules
149 FROM room_stats_state rss
150 LEFT JOIN room_stats_current rsc USING (room_id)
151 WHERE room_id IN (
152     '!OGEhHVWSdvArJzumhm:matrix.org',
153     '!YTvKGNlinIzlkMTVRl:matrix.org'
154 );
155 ```
209209 ^/_matrix/federation/v1/get_groups_publicised$
210210 ^/_matrix/key/v2/query
211211 ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/
212 ^/_matrix/federation/unstable/org.matrix.msc2946/hierarchy/
212 ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/
213213
214214 # Inbound federation transaction request
215215 ^/_matrix/federation/v1/send/
222222 ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$
223223 ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$
224224 ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$
225 ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/hierarchy$
225 ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$
226226 ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$
227227 ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$
228228 ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$
3232 |synapse/storage/databases/main/event_federation.py
3333 |synapse/storage/databases/main/event_push_actions.py
3434 |synapse/storage/databases/main/events_bg_updates.py
35 |synapse/storage/databases/main/events_worker.py
3635 |synapse/storage/databases/main/group_server.py
3736 |synapse/storage/databases/main/metrics.py
3837 |synapse/storage/databases/main/monthly_active_users.py
8685 |tests/push/test_presentable_names.py
8786 |tests/push/test_push_rule_evaluator.py
8887 |tests/rest/admin/test_admin.py
89 |tests/rest/admin/test_device.py
90 |tests/rest/admin/test_media.py
91 |tests/rest/admin/test_server_notice.py
9288 |tests/rest/admin/test_user.py
9389 |tests/rest/admin/test_username_available.py
9490 |tests/rest/client/test_account.py
111107 |tests/server_notices/test_resource_limits_server_notices.py
112108 |tests/state/test_v2.py
113109 |tests/storage/test_account_data.py
114 |tests/storage/test_appservice.py
115110 |tests/storage/test_background_update.py
116111 |tests/storage/test_base.py
117112 |tests/storage/test_client_ips.py
124119 |tests/test_server.py
125120 |tests/test_state.py
126121 |tests/test_terms_auth.py
127 |tests/test_visibility.py
128122 |tests/unittest.py
129123 |tests/util/caches/test_cached_call.py
130124 |tests/util/caches/test_deferred_cache.py
159153 [mypy-synapse.events.*]
160154 disallow_untyped_defs = True
161155
156 [mypy-synapse.federation.*]
157 disallow_untyped_defs = True
158
159 [mypy-synapse.federation.transport.client]
160 disallow_untyped_defs = False
161
162162 [mypy-synapse.handlers.*]
163163 disallow_untyped_defs = True
164164
165165 [mypy-synapse.metrics.*]
166166 disallow_untyped_defs = True
167167
168 [mypy-synapse.module_api.*]
169 disallow_untyped_defs = True
170
168171 [mypy-synapse.push.*]
169172 disallow_untyped_defs = True
170173
181184 disallow_untyped_defs = True
182185
183186 [mypy-synapse.storage.databases.main.directory]
187 disallow_untyped_defs = True
188
189 [mypy-synapse.storage.databases.main.events_worker]
184190 disallow_untyped_defs = True
185191
186192 [mypy-synapse.storage.databases.main.room_batch]
218224
219225 [mypy-tests.rest.client.test_directory]
220226 disallow_untyped_defs = True
227
228 [mypy-tests.federation.transport.test_client]
229 disallow_untyped_defs = True
230
221231
222232 ;; Dependencies without annotations
223233 ;; Before ignoring a module, check to see if type stubs are available.
6464 fi
6565
6666 # Run the tests!
67 go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
67 go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616
17
18 """
19 Script for signing and sending federation requests.
20
21 Some tips on doing the join dance with this:
22
23 room_id=...
24 user_id=...
25
26 # make_join
27 federation_client.py "/_matrix/federation/v1/make_join/$room_id/$user_id?ver=5" > make_join.json
28
29 # sign
30 jq -M .event make_join.json | sign_json --sign-event-room-version=$(jq -r .room_version make_join.json) -o signed-join.json
31
32 # send_join
33 federation_client.py -X PUT "/_matrix/federation/v2/send_join/$room_id/x" --body $(<signed-join.json) > send_join.json
34 """
35
1736 import argparse
1837 import base64
1938 import json
2121 from signedjson.key import read_signing_keys
2222 from signedjson.sign import sign_json
2323
24 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
25 from synapse.crypto.event_signing import add_hashes_and_signatures
2426 from synapse.util import json_encoder
2527
2628
6466 "Path to synapse config file, from which the server name and/or signing "
6567 "key path will be read. Ignored if --server-name and --signing-key(-path) "
6668 "are both given."
69 ),
70 )
71
72 parser.add_argument(
73 "--sign-event-room-version",
74 type=str,
75 help=(
76 "Sign the JSON as an event for the given room version, rather than raw JSON. "
77 "This means that we will add a 'hashes' object, and redact the event before "
78 "signing."
6779 ),
6880 )
6981
115127 print("Input json was not an object", file=sys.stderr)
116128 sys.exit(1)
117129
118 sign_json(obj, args.server_name, keys[0])
130 if args.sign_event_room_version:
131 room_version = KNOWN_ROOM_VERSIONS.get(args.sign_event_room_version)
132 if not room_version:
133 print(
134 f"Unknown room version {args.sign_event_room_version}", file=sys.stderr
135 )
136 sys.exit(1)
137 add_hashes_and_signatures(room_version, obj, args.server_name, keys[0])
138 else:
139 sign_json(obj, args.server_name, keys[0])
140
119141 for c in json_encoder.iterencode(obj):
120142 args.output.write(c)
121143 args.output.write("\n")
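The new `--sign-event-room-version` branch above calls `add_hashes_and_signatures`, which (per the Matrix server-server spec) first adds a sha256 content hash over the canonical-JSON event before signing, rather than signing the raw JSON. A dependency-free sketch of that content-hash step, simplified from the spec (the helper name is illustrative, and real room versions also apply redaction rules before signing):

```python
import base64
import hashlib
import json


def compute_content_hash(event: dict) -> str:
    """Sketch of the Matrix content hash: sha256 over the canonical-JSON
    event with 'hashes', 'signatures' and 'unsigned' stripped, encoded as
    unpadded base64 (simplified)."""
    stripped = {
        k: v for k, v in event.items()
        if k not in ("hashes", "signatures", "unsigned")
    }
    # Canonical JSON: sorted keys, no insignificant whitespace.
    canonical = json.dumps(
        stripped, sort_keys=True, separators=(",", ":"), ensure_ascii=False
    ).encode("utf-8")
    digest = hashlib.sha256(canonical).digest()
    return base64.b64encode(digest).decode("ascii").rstrip("=")


event = {"type": "m.room.member", "content": {"membership": "join"}}
event.setdefault("hashes", {})["sha256"] = compute_content_hash(event)
```

Existing `hashes`/`signatures` keys are excluded from the hash input, so adding the hash afterwards does not invalidate it.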
118118 # Tests assume that all optional dependencies are installed.
119119 #
120120 # parameterized_class decorator was introduced in parameterized 0.7.0
121 CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
121 #
122 # We use `mock` library as that backports `AsyncMock` to Python 3.6
123 CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
122124
123125 CONDITIONAL_REQUIREMENTS["dev"] = (
124126 CONDITIONAL_REQUIREMENTS["lint"]
149151 long_description=long_description,
150152 long_description_content_type="text/x-rst",
151153 python_requires="~=3.6",
154 entry_points={
155 "console_scripts": [
156 "synapse_homeserver = synapse.app.homeserver:main",
157 "synapse_worker = synapse.app.generic_worker:main",
158 ]
159 },
152160 classifiers=[
153161 "Development Status :: 5 - Production/Stable",
154162 "Topic :: Communications :: Chat",
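The `entry_points` addition above makes pip generate `synapse_homeserver` and `synapse_worker` wrapper executables (which is also why `generic_worker` gains a `main()` function later in this diff). Conceptually, each `console_scripts` wrapper just imports the named module and calls the named attribute; a sketch of that resolution step (the `demo_app` module here is fabricated for illustration):

```python
import importlib
import sys
import types


def run_entry_point(spec: str) -> None:
    """Resolve a 'module:function' console_scripts spec and call it,
    roughly as a pip-generated wrapper script does."""
    module_name, func_name = spec.split(":")
    module = importlib.import_module(module_name)
    getattr(module, func_name)()


# Simulate an installed module exposing a main() callable:
demo = types.ModuleType("demo_app")
calls = []
demo.main = lambda: calls.append("ran")
sys.modules["demo_app"] = demo

run_entry_point("demo_app:main")
```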
4646 except ImportError:
4747 pass
4848
49 __version__ = "1.48.0"
49 __version__ = "1.49.0"
5050
5151 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
5252 # We import here so that we don't have to install a bunch of deps when
1616
1717 """Contains constants from the specification."""
1818
19 from typing_extensions import Final
20
1921 # the max size of a (canonical-json-encoded) event
2022 MAX_PDU_SIZE = 65536
2123
3840
3941 """Represents the membership states of a user in a room."""
4042
41 INVITE = "invite"
42 JOIN = "join"
43 KNOCK = "knock"
44 LEAVE = "leave"
45 BAN = "ban"
46 LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
43 INVITE: Final = "invite"
44 JOIN: Final = "join"
45 KNOCK: Final = "knock"
46 LEAVE: Final = "leave"
47 BAN: Final = "ban"
48 LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
4749
4850
4951 class PresenceState:
5052 """Represents the presence state of a user."""
5153
52 OFFLINE = "offline"
53 UNAVAILABLE = "unavailable"
54 ONLINE = "online"
55 BUSY = "org.matrix.msc3026.busy"
54 OFFLINE: Final = "offline"
55 UNAVAILABLE: Final = "unavailable"
56 ONLINE: Final = "online"
57 BUSY: Final = "org.matrix.msc3026.busy"
5658
5759
5860 class JoinRules:
59 PUBLIC = "public"
60 KNOCK = "knock"
61 INVITE = "invite"
62 PRIVATE = "private"
61 PUBLIC: Final = "public"
62 KNOCK: Final = "knock"
63 INVITE: Final = "invite"
64 PRIVATE: Final = "private"
6365 # As defined for MSC3083.
64 RESTRICTED = "restricted"
66 RESTRICTED: Final = "restricted"
6567
6668
6769 class RestrictedJoinRuleTypes:
6870 """Understood types for the allow rules in restricted join rules."""
6971
70 ROOM_MEMBERSHIP = "m.room_membership"
72 ROOM_MEMBERSHIP: Final = "m.room_membership"
7173
7274
7375 class LoginType:
74 PASSWORD = "m.login.password"
75 EMAIL_IDENTITY = "m.login.email.identity"
76 MSISDN = "m.login.msisdn"
77 RECAPTCHA = "m.login.recaptcha"
78 TERMS = "m.login.terms"
79 SSO = "m.login.sso"
80 DUMMY = "m.login.dummy"
81 REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
76 PASSWORD: Final = "m.login.password"
77 EMAIL_IDENTITY: Final = "m.login.email.identity"
78 MSISDN: Final = "m.login.msisdn"
79 RECAPTCHA: Final = "m.login.recaptcha"
80 TERMS: Final = "m.login.terms"
81 SSO: Final = "m.login.sso"
82 DUMMY: Final = "m.login.dummy"
83 REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token"
8284
8385
8486 # This is used in the `type` parameter for /register when called by
8587 # an appservice to register a new user.
86 APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service"
88 APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service"
8789
8890
8991 class EventTypes:
90 Member = "m.room.member"
91 Create = "m.room.create"
92 Tombstone = "m.room.tombstone"
93 JoinRules = "m.room.join_rules"
94 PowerLevels = "m.room.power_levels"
95 Aliases = "m.room.aliases"
96 Redaction = "m.room.redaction"
97 ThirdPartyInvite = "m.room.third_party_invite"
98 RelatedGroups = "m.room.related_groups"
99
100 RoomHistoryVisibility = "m.room.history_visibility"
101 CanonicalAlias = "m.room.canonical_alias"
102 Encrypted = "m.room.encrypted"
103 RoomAvatar = "m.room.avatar"
104 RoomEncryption = "m.room.encryption"
105 GuestAccess = "m.room.guest_access"
92 Member: Final = "m.room.member"
93 Create: Final = "m.room.create"
94 Tombstone: Final = "m.room.tombstone"
95 JoinRules: Final = "m.room.join_rules"
96 PowerLevels: Final = "m.room.power_levels"
97 Aliases: Final = "m.room.aliases"
98 Redaction: Final = "m.room.redaction"
99 ThirdPartyInvite: Final = "m.room.third_party_invite"
100 RelatedGroups: Final = "m.room.related_groups"
101
102 RoomHistoryVisibility: Final = "m.room.history_visibility"
103 CanonicalAlias: Final = "m.room.canonical_alias"
104 Encrypted: Final = "m.room.encrypted"
105 RoomAvatar: Final = "m.room.avatar"
106 RoomEncryption: Final = "m.room.encryption"
107 GuestAccess: Final = "m.room.guest_access"
106108
107109 # These are used for validation
108 Message = "m.room.message"
109 Topic = "m.room.topic"
110 Name = "m.room.name"
111
112 ServerACL = "m.room.server_acl"
113 Pinned = "m.room.pinned_events"
114
115 Retention = "m.room.retention"
116
117 Dummy = "org.matrix.dummy_event"
118
119 SpaceChild = "m.space.child"
120 SpaceParent = "m.space.parent"
121
122 MSC2716_INSERTION = "org.matrix.msc2716.insertion"
123 MSC2716_BATCH = "org.matrix.msc2716.batch"
124 MSC2716_MARKER = "org.matrix.msc2716.marker"
110 Message: Final = "m.room.message"
111 Topic: Final = "m.room.topic"
112 Name: Final = "m.room.name"
113
114 ServerACL: Final = "m.room.server_acl"
115 Pinned: Final = "m.room.pinned_events"
116
117 Retention: Final = "m.room.retention"
118
119 Dummy: Final = "org.matrix.dummy_event"
120
121 SpaceChild: Final = "m.space.child"
122 SpaceParent: Final = "m.space.parent"
123
124 MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion"
125 MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
126 MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
125127
126128
127129 class ToDeviceEventTypes:
128 RoomKeyRequest = "m.room_key_request"
130 RoomKeyRequest: Final = "m.room_key_request"
129131
130132
131133 class DeviceKeyAlgorithms:
132134 """Spec'd algorithms for the generation of per-device keys"""
133135
134 ED25519 = "ed25519"
135 CURVE25519 = "curve25519"
136 SIGNED_CURVE25519 = "signed_curve25519"
136 ED25519: Final = "ed25519"
137 CURVE25519: Final = "curve25519"
138 SIGNED_CURVE25519: Final = "signed_curve25519"
137139
138140
139141 class EduTypes:
140 Presence = "m.presence"
142 Presence: Final = "m.presence"
141143
142144
143145 class RejectedReason:
144 AUTH_ERROR = "auth_error"
146 AUTH_ERROR: Final = "auth_error"
145147
146148
147149 class RoomCreationPreset:
148 PRIVATE_CHAT = "private_chat"
149 PUBLIC_CHAT = "public_chat"
150 TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
150 PRIVATE_CHAT: Final = "private_chat"
151 PUBLIC_CHAT: Final = "public_chat"
152 TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat"
151153
152154
153155 class ThirdPartyEntityKind:
154 USER = "user"
155 LOCATION = "location"
156
157
158 ServerNoticeMsgType = "m.server_notice"
159 ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
156 USER: Final = "user"
157 LOCATION: Final = "location"
158
159
160 ServerNoticeMsgType: Final = "m.server_notice"
161 ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached"
160162
161163
162164 class UserTypes:
164166 'admin' and 'guest' users should also be UserTypes. Normal users are type None
165167 """
166168
167 SUPPORT = "support"
168 BOT = "bot"
169 ALL_USER_TYPES = (SUPPORT, BOT)
169 SUPPORT: Final = "support"
170 BOT: Final = "bot"
171 ALL_USER_TYPES: Final = (SUPPORT, BOT)
170172
171173
172174 class RelationTypes:
173175 """The types of relations known to this server."""
174176
175 ANNOTATION = "m.annotation"
176 REPLACE = "m.replace"
177 REFERENCE = "m.reference"
178 THREAD = "io.element.thread"
177 ANNOTATION: Final = "m.annotation"
178 REPLACE: Final = "m.replace"
179 REFERENCE: Final = "m.reference"
180 THREAD: Final = "io.element.thread"
179181
180182
181183 class LimitBlockingTypes:
182184 """Reasons that a server may be blocked"""
183185
184 MONTHLY_ACTIVE_USER = "monthly_active_user"
185 HS_DISABLED = "hs_disabled"
186 MONTHLY_ACTIVE_USER: Final = "monthly_active_user"
187 HS_DISABLED: Final = "hs_disabled"
186188
187189
188190 class EventContentFields:
189191 """Fields found in events' content, regardless of type."""
190192
191193 # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
192 LABELS = "org.matrix.labels"
194 LABELS: Final = "org.matrix.labels"
193195
194196 # Timestamp to delete the event after
195197 # cf https://github.com/matrix-org/matrix-doc/pull/2228
196 SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
198 SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after"
197199
198200 # cf https://github.com/matrix-org/matrix-doc/pull/1772
199 ROOM_TYPE = "type"
201 ROOM_TYPE: Final = "type"
200202
201203 # Whether a room can federate.
202 FEDERATE = "m.federate"
204 FEDERATE: Final = "m.federate"
203205
204206 # The creator of the room, as used in `m.room.create` events.
205 ROOM_CREATOR = "creator"
207 ROOM_CREATOR: Final = "creator"
206208
207209 # Used in m.room.guest_access events.
208 GUEST_ACCESS = "guest_access"
210 GUEST_ACCESS: Final = "guest_access"
209211
210212 # Used on normal messages to indicate they were historically imported after the fact
211 MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
213 MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
212214 # For "insertion" events to indicate what the next batch ID should be in
213215 # order to connect to it
214 MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
216 MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
215217 # Used on "batch" events to indicate which insertion event it connects to
216 MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
218 MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
217219 # For "marker" events
218 MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
220 MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
219221
220222 # The authorising user for joining a restricted room.
221 AUTHORISING_USER = "join_authorised_via_users_server"
223 AUTHORISING_USER: Final = "join_authorised_via_users_server"
222224
223225
224226 class RoomTypes:
225227 """Understood values of the room_type field of m.room.create events."""
226228
227 SPACE = "m.space"
229 SPACE: Final = "m.space"
228230
229231
230232 class RoomEncryptionAlgorithms:
231 MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
232 DEFAULT = MEGOLM_V1_AES_SHA2
233 MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2"
234 DEFAULT: Final = MEGOLM_V1_AES_SHA2
233235
234236
235237 class AccountDataTypes:
236 DIRECT = "m.direct"
237 IGNORED_USER_LIST = "m.ignored_user_list"
238 DIRECT: Final = "m.direct"
239 IGNORED_USER_LIST: Final = "m.ignored_user_list"
238240
239241
240242 class HistoryVisibility:
241 INVITED = "invited"
242 JOINED = "joined"
243 SHARED = "shared"
244 WORLD_READABLE = "world_readable"
243 INVITED: Final = "invited"
244 JOINED: Final = "joined"
245 SHARED: Final = "shared"
246 WORLD_READABLE: Final = "world_readable"
245247
246248
247249 class GuestAccess:
248 CAN_JOIN = "can_join"
250 CAN_JOIN: Final = "can_join"
249251 # anything that is not "can_join" is considered "forbidden", but for completeness:
250 FORBIDDEN = "forbidden"
252 FORBIDDEN: Final = "forbidden"
251253
252254
253255 class ReadReceiptEventFields:
254 MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
256 MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
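The `Final` annotations added throughout constants.py exist purely for the type checker: mypy rejects any rebinding of these names, while runtime behaviour is unchanged. A minimal sketch (using `typing.Final`, available in the stdlib from Python 3.8; the diff imports it from `typing_extensions` because Synapse still supports Python 3.6):

```python
from typing import Final


class Membership:
    INVITE: Final = "invite"
    JOIN: Final = "join"


# At runtime these are ordinary class attributes:
assert Membership.JOIN == "join"

# But mypy now flags any reassignment:
# Membership.JOIN = "left"   # error: cannot assign to a Final attribute
```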
3131 Iterable,
3232 List,
3333 NoReturn,
34 Optional,
3435 Tuple,
3536 cast,
3637 )
128129 def start_reactor(
129130 appname: str,
130131 soft_file_limit: int,
131 gc_thresholds: Tuple[int, int, int],
132 gc_thresholds: Optional[Tuple[int, int, int]],
132133 pid_file: str,
133134 daemonize: bool,
134135 print_pidfile: bool,
112112 )
113113 from synapse.storage.databases.main.presence import PresenceStore
114114 from synapse.storage.databases.main.room import RoomWorkerStore
115 from synapse.storage.databases.main.room_batch import RoomBatchStore
115116 from synapse.storage.databases.main.search import SearchStore
116117 from synapse.storage.databases.main.session import SessionStore
117118 from synapse.storage.databases.main.stats import StatsStore
239240 SlavedEventStore,
240241 SlavedKeyStore,
241242 RoomWorkerStore,
243 RoomBatchStore,
242244 DirectoryStore,
243245 SlavedApplicationServiceStore,
244246 SlavedRegistrationStore,
502504 _base.start_worker_reactor("synapse-generic-worker", config)
503505
504506
505 if __name__ == "__main__":
507 def main() -> None:
506508 with LoggingContext("main"):
507509 start(sys.argv[1:])
510
511
512 if __name__ == "__main__":
513 main()
193193 {
194194 "/_matrix/client/api/v1": client_resource,
195195 "/_matrix/client/r0": client_resource,
196 "/_matrix/client/v1": client_resource,
196197 "/_matrix/client/v3": client_resource,
197198 "/_matrix/client/unstable": client_resource,
198199 "/_matrix/client/v2_alpha": client_resource,
356357 # generating config files and shouldn't try to continue.
357358 sys.exit(0)
358359
360 if config.worker.worker_app:
361 raise ConfigError(
362 "You have specified `worker_app` in the config but are attempting to start a non-worker "
363 "instance. Please use `python -m synapse.app.generic_worker` instead (or remove the option if this is the main process)."
364 )
365 sys.exit(1)
366
359367 events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
360368 synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
361369
1212 # limitations under the License.
1313 import logging
1414 import re
15 from enum import Enum
1516 from typing import TYPE_CHECKING, Iterable, List, Match, Optional
1617
1718 from synapse.api.constants import EventTypes
2627 logger = logging.getLogger(__name__)
2728
2829
29 class ApplicationServiceState:
30 class ApplicationServiceState(Enum):
3031 DOWN = "down"
3132 UP = "up"
3233
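Turning `ApplicationServiceState` into an `Enum` (as above) makes the states distinct members rather than bare strings: comparisons against raw values no longer silently succeed, and converting a stored value back into a state becomes an explicit, validated lookup. A small sketch:

```python
from enum import Enum


class ApplicationServiceState(Enum):
    DOWN = "down"
    UP = "up"


# Look up a member from the value stored in the database:
state = ApplicationServiceState("up")
assert state is ApplicationServiceState.UP

# Enum members are not equal to their raw values, which catches
# code that still compares against plain strings:
assert ApplicationServiceState.DOWN != "down"

# The stored value is recovered via .value:
assert ApplicationServiceState.DOWN.value == "down"
```

An unknown value such as `ApplicationServiceState("sideways")` raises `ValueError` instead of propagating a bad string.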
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import sys
15 from typing import List
1516
1617 from synapse.config._base import ConfigError
1718 from synapse.config.homeserver import HomeServerConfig
1819
1920
20 def main(args):
21 def main(args: List[str]) -> None:
2122 action = args[1] if len(args) > 1 and args[1] == "read" else None
2223 # If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]`
2324 # will be the key to read.
00 # Copyright 2015, 2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
1213 # limitations under the License.
1314
1415 import logging
15 from typing import Dict
16 from typing import Dict, List
1617 from urllib import parse as urlparse
1718
1819 import yaml
1920 from netaddr import IPSet
2021
2122 from synapse.appservice import ApplicationService
22 from synapse.types import UserID
23 from synapse.types import JsonDict, UserID
2324
2425 from ._base import Config, ConfigError
2526
2930 class AppServiceConfig(Config):
3031 section = "appservice"
3132
32 def read_config(self, config, **kwargs):
33 def read_config(self, config, **kwargs) -> None:
3334 self.app_service_config_files = config.get("app_service_config_files", [])
3435 self.notify_appservices = config.get("notify_appservices", True)
3536 self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
3637
37 def generate_config_section(cls, **kwargs):
38 def generate_config_section(cls, **kwargs) -> str:
3839 return """\
3940 # A list of application service config files to use
4041 #
4950 """
5051
5152
52 def load_appservices(hostname, config_files):
53 def load_appservices(
54 hostname: str, config_files: List[str]
55 ) -> List[ApplicationService]:
5356 """Returns a list of Application Services from the config files."""
5457 if not isinstance(config_files, list):
5558 logger.warning("Expected %s to be a list of AS config files.", config_files)
9295 return appservices
9396
9497
95 def _load_appservice(hostname, as_info, config_filename):
98 def _load_appservice(
99 hostname: str, as_info: JsonDict, config_filename: str
100 ) -> ApplicationService:
96101 required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
97102 for field in required_string_fields:
98103 if not isinstance(as_info.get(field), str):
114119 user_id = user.to_string()
115120
116121 # Rate limiting for users of this AS is on by default (excludes sender)
117 rate_limited = True
118 if isinstance(as_info.get("rate_limited"), bool):
119 rate_limited = as_info.get("rate_limited")
122 rate_limited = as_info.get("rate_limited")
123 if not isinstance(rate_limited, bool):
124 rate_limited = True
120125
121126 # namespace checks
122127 if not isinstance(as_info.get("namespaces"), dict):
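The reworked `rate_limited` handling above reads the raw value once and falls back to the default only when it is not a real bool, replacing the old probe-then-read double lookup. The pattern, extracted as a standalone helper (the helper name is illustrative):

```python
from typing import Any, Dict


def read_bool(config: Dict[str, Any], key: str, default: bool = True) -> bool:
    """Return config[key] if it is a real bool, else the default.

    Mirrors the rate_limited parsing in the diff above: non-bool values
    (including truthy strings) fall back rather than being coerced.
    """
    value = config.get(key)
    if not isinstance(value, bool):
        return default
    return value


assert read_bool({}, "rate_limited") is True
assert read_bool({"rate_limited": False}, "rate_limited") is False
# A string like "false" is not a bool, so the default wins:
assert read_bool({"rate_limited": "false"}, "rate_limited") is True
```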
0 # Copyright 2019 Matrix.org Foundation C.I.C.
0 # Copyright 2019-2021 Matrix.org Foundation C.I.C.
11 #
22 # Licensed under the Apache License, Version 2.0 (the "License");
33 # you may not use this file except in compliance with the License.
1616 import threading
1717 from typing import Callable, Dict, Optional
1818
19 import attr
20
1921 from synapse.python_dependencies import DependencyException, check_requirements
2022
2123 from ._base import Config, ConfigError
3335 _DEFAULT_EVENT_CACHE_SIZE = "10K"
3436
3537
38 @attr.s(slots=True, auto_attribs=True)
3639 class CacheProperties:
37 def __init__(self):
38 # The default factor size for all caches
39 self.default_factor_size = float(
40 os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
41 )
42 self.resize_all_caches_func = None
40 # The default factor size for all caches
41 default_factor_size: float = float(
42 os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
43 )
44 resize_all_caches_func: Optional[Callable[[], None]] = None
4345
4446
4547 properties = CacheProperties()
6163
6264 def add_resizable_cache(
6365 cache_name: str, cache_resize_callback: Callable[[float], None]
64 ):
66 ) -> None:
6567 """Register a cache whose size can dynamically change
6668
6769 Args:
9092 _environ = os.environ
9193
9294 @staticmethod
93 def reset():
95 def reset() -> None:
9496 """Resets the caches to their defaults. Used for tests."""
9597 properties.default_factor_size = float(
9698 os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
99101 with _CACHES_LOCK:
100102 _CACHES.clear()
101103
102 def generate_config_section(self, **kwargs):
104 def generate_config_section(self, **kwargs) -> str:
103105 return """\
104106 ## Caching ##
105107
161163 #sync_response_cache_duration: 2m
162164 """
163165
164 def read_config(self, config, **kwargs):
166 def read_config(self, config, **kwargs) -> None:
165167 self.event_cache_size = self.parse_size(
166168 config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
167169 )
231233 # needing an instance of Config
232234 properties.resize_all_caches_func = self.resize_all_caches
233235
234 def resize_all_caches(self):
236 def resize_all_caches(self) -> None:
235237 """Ensure all cache sizes are up to date
236238
237239 For each cache, run the mapped callback function with either
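`CacheProperties` above swaps a hand-written `__init__` for `attr.s(slots=True, auto_attribs=True)`, which generates the constructor from the annotated class-level fields. The stdlib `dataclasses` module behaves very similarly; here is a dependency-free sketch of the same shape (the env-var default mirrors the diff, but the `0.5` fallback is an assumption):

```python
import os
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class CacheProperties:
    # Defaults are evaluated once, at class-definition time, just as the
    # attrs version reads SYNAPSE_CACHE_FACTOR from the environment.
    default_factor_size: float = float(
        os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)
    )
    resize_all_caches_func: Optional[Callable[[], None]] = None


props = CacheProperties()
assert props.resize_all_caches_func is None
```

With `attrs`, `slots=True` additionally prevents setting undeclared attributes; plain `@dataclass` needs `@dataclass(slots=True)` (Python 3.10+) for the same effect.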
00 # Copyright 2015, 2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
2728
2829 section = "cas"
2930
30 def read_config(self, config, **kwargs):
31 def read_config(self, config, **kwargs) -> None:
3132 cas_config = config.get("cas_config", None)
3233 self.cas_enabled = cas_config and cas_config.get("enabled", True)
3334
5051 self.cas_displayname_attribute = None
5152 self.cas_required_attributes = []
5253
53 def generate_config_section(self, config_dir_path, server_name, **kwargs):
54 def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
5455 return """\
5556 # Enable Central Authentication Service (CAS) for registration and login.
5657 #
00 # Copyright 2014-2016 OpenMarket Ltd
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
1 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
22 #
33 # Licensed under the Apache License, Version 2.0 (the "License");
44 # you may not use this file except in compliance with the License.
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import argparse
1415 import logging
1516 import os
1617
118119
119120 self.databases = []
120121
121 def read_config(self, config, **kwargs):
122 def read_config(self, config, **kwargs) -> None:
122123 # We *experimentally* support specifying multiple databases via the
123124 # `databases` key. This is a map from a label to database config in the
124125 # same format as the `database` config option, plus an extra
162163 self.databases = [DatabaseConnectionConfig("master", database_config)]
163164 self.set_databasepath(database_path)
164165
165 def generate_config_section(self, data_dir_path, **kwargs):
166 def generate_config_section(self, data_dir_path, **kwargs) -> str:
166167 return DEFAULT_CONFIG % {
167168 "database_path": os.path.join(data_dir_path, "homeserver.db")
168169 }
169170
170 def read_arguments(self, args):
171 def read_arguments(self, args: argparse.Namespace) -> None:
171172 """
172173 Cases for the cli input:
173174 - If no databases are configured and no database_path is set, raise.
193194 else:
194195 logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
195196
196 def set_databasepath(self, database_path):
197 def set_databasepath(self, database_path: str) -> None:
197198
198199 if database_path != ":memory:":
199200 database_path = self.abspath(database_path)
201202 self.databases[0].config["args"]["database"] = database_path
202203
203204 @staticmethod
204 def add_arguments(parser):
205 def add_arguments(parser: argparse.ArgumentParser) -> None:
205206 db_group = parser.add_argument_group("database")
206207 db_group.add_argument(
207208 "-d",
4545
4646 # MSC3266 (room summary api)
4747 self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
48
49 # MSC3030 (Jump to date API endpoint)
50 self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
00 # Copyright 2014-2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
1718 import sys
1819 import threading
1920 from string import Template
20 from typing import TYPE_CHECKING, Any, Dict
21 from typing import TYPE_CHECKING, Any, Dict, Optional
2122
2223 import yaml
2324 from zope.interface import implementer
3940 from ._base import Config, ConfigError
4041
4142 if TYPE_CHECKING:
43 from synapse.config.homeserver import HomeServerConfig
4244 from synapse.server import HomeServer
4345
4446 DEFAULT_LOG_CONFIG = Template(
140142 class LoggingConfig(Config):
141143 section = "logging"
142144
143 def read_config(self, config, **kwargs):
145 def read_config(self, config, **kwargs) -> None:
144146 if config.get("log_file"):
145147 raise ConfigError(LOG_FILE_ERROR)
146148 self.log_config = self.abspath(config.get("log_config"))
147149 self.no_redirect_stdio = config.get("no_redirect_stdio", False)
148150
149 def generate_config_section(self, config_dir_path, server_name, **kwargs):
151 def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
150152 log_config = os.path.join(config_dir_path, server_name + ".log.config")
151153 return (
152154 """\
160162 % locals()
161163 )
162164
163 def read_arguments(self, args):
165 def read_arguments(self, args: argparse.Namespace) -> None:
164166 if args.no_redirect_stdio is not None:
165167 self.no_redirect_stdio = args.no_redirect_stdio
166168 if args.log_file is not None:
167169 raise ConfigError(LOG_FILE_ERROR)
168170
169171 @staticmethod
170 def add_arguments(parser):
172 def add_arguments(parser: argparse.ArgumentParser) -> None:
171173 logging_group = parser.add_argument_group("logging")
172174 logging_group.add_argument(
173175 "-n",
196198 log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
197199
198200
199 def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
201 def _setup_stdlib_logging(
202 config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner
203 ) -> None:
200204 """
201205 Set up Python standard library logging.
202206 """
229233 log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
230234 old_factory = logging.getLogRecordFactory()
231235
232 def factory(*args, **kwargs):
236 def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
233237 record = old_factory(*args, **kwargs)
234238 log_context_filter.filter(record)
235239 log_metadata_filter.filter(record)
296300 logging.config.dictConfig(log_config)
297301
298302
299 def _reload_logging_config(log_config_path):
303 def _reload_logging_config(log_config_path: Optional[str]) -> None:
300304 """
301305 Reload the log configuration from the file and apply it.
302306 """
310314
311315 def setup_logging(
312316 hs: "HomeServer",
313 config,
314 use_worker_options=False,
317 config: "HomeServerConfig",
318 use_worker_options: bool = False,
315319 logBeginner: LogBeginner = globalLogBeginner,
316320 ) -> None:
317321 """
1313 # limitations under the License.
1414
1515 from collections import Counter
16 from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type
16 from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type
1717
1818 import attr
1919
3535 class OIDCConfig(Config):
3636 section = "oidc"
3737
38 def read_config(self, config, **kwargs):
38 def read_config(self, config, **kwargs) -> None:
3939 self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
4040 if not self.oidc_providers:
4141 return
6565 # OIDC is enabled if we have a provider
6666 return bool(self.oidc_providers)
6767
68 def generate_config_section(self, config_dir_path, server_name, **kwargs):
68 def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
6969 return """\
7070 # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
7171 # and login.
494494 )
495495
496496
497 @attr.s(slots=True, frozen=True)
497 @attr.s(slots=True, frozen=True, auto_attribs=True)
498498 class OidcProviderClientSecretJwtKey:
499499 # a pem-encoded signing key
500 key = attr.ib(type=str)
500 key: str
501501
502502 # properties to include in the JWT header
503 jwt_header = attr.ib(type=Mapping[str, str])
503 jwt_header: Mapping[str, str]
504504
505505 # properties to include in the JWT payload.
506 jwt_payload = attr.ib(type=Mapping[str, str])
507
508
509 @attr.s(slots=True, frozen=True)
506 jwt_payload: Mapping[str, str]
507
508
509 @attr.s(slots=True, frozen=True, auto_attribs=True)
510510 class OidcProviderConfig:
511511 # a unique identifier for this identity provider. Used in the 'user_external_ids'
512512 # table, as well as the query/path parameter used in the login protocol.
513 idp_id = attr.ib(type=str)
513 idp_id: str
514514
515515 # user-facing name for this identity provider.
516 idp_name = attr.ib(type=str)
516 idp_name: str
517517
518518 # Optional MXC URI for icon for this IdP.
519 idp_icon = attr.ib(type=Optional[str])
519 idp_icon: Optional[str]
520520
521521 # Optional brand identifier for this IdP.
522 idp_brand = attr.ib(type=Optional[str])
522 idp_brand: Optional[str]
523523
524524 # whether the OIDC discovery mechanism is used to discover endpoints
525 discover = attr.ib(type=bool)
525 discover: bool
526526
527527 # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
528528 # discover the provider's endpoints.
529 issuer = attr.ib(type=str)
529 issuer: str
530530
531531 # oauth2 client id to use
532 client_id = attr.ib(type=str)
532 client_id: str
533533
534534 # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
535535 # a secret.
536 client_secret = attr.ib(type=Optional[str])
536 client_secret: Optional[str]
537537
538538 # key to use to construct a JWT to use as a client secret. May be `None` if
539539 # `client_secret` is set.
540 client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
540 client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey]
541541
542542 # auth method to use when exchanging the token.
543543 # Valid values are 'client_secret_basic', 'client_secret_post' and
544544 # 'none'.
545 client_auth_method = attr.ib(type=str)
545 client_auth_method: str
546546
547547 # list of scopes to request
548 scopes = attr.ib(type=Collection[str])
548 scopes: Collection[str]
549549
550550 # the oauth2 authorization endpoint. Required if discovery is disabled.
551 authorization_endpoint = attr.ib(type=Optional[str])
551 authorization_endpoint: Optional[str]
552552
553553 # the oauth2 token endpoint. Required if discovery is disabled.
554 token_endpoint = attr.ib(type=Optional[str])
554 token_endpoint: Optional[str]
555555
556556 # the OIDC userinfo endpoint. Required if discovery is disabled and the
557557 # "openid" scope is not requested.
558 userinfo_endpoint = attr.ib(type=Optional[str])
558 userinfo_endpoint: Optional[str]
559559
560560 # URI where to fetch the JWKS. Required if discovery is disabled and the
561561 # "openid" scope is used.
562 jwks_uri = attr.ib(type=Optional[str])
562 jwks_uri: Optional[str]
563563
564564 # Whether to skip metadata verification
565 skip_verification = attr.ib(type=bool)
565 skip_verification: bool
566566
567567 # Whether to fetch the user profile from the userinfo endpoint. Valid
568568 # values are: "auto" or "userinfo_endpoint".
569 user_profile_method = attr.ib(type=str)
569 user_profile_method: str
570570
571571 # whether to allow a user logging in via OIDC to match a pre-existing account
572572 # instead of failing
573 allow_existing_users = attr.ib(type=bool)
573 allow_existing_users: bool
574574
575575 # the class of the user mapping provider
576 user_mapping_provider_class = attr.ib(type=Type)
576 user_mapping_provider_class: Type
577577
578578 # the config of the user mapping provider
579 user_mapping_provider_config = attr.ib()
579 user_mapping_provider_config: Any
580580
581581 # required attributes to require in userinfo to allow login/registration
582 attribute_requirements = attr.ib(type=List[SsoAttributeRequirement])
582 attribute_requirements: List[SsoAttributeRequirement]
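The hunk above converts `attr.ib(type=...)` field declarations into plain class-level annotations, which relies on attrs' `auto_attribs` mode. A minimal before/after sketch, using a hypothetical config class (not one from Synapse):

```python
import attr
from typing import Optional

# Old style: fields declared via attr.ib() with a `type` argument.
@attr.s(frozen=True, slots=True)
class OldStyle:
    name = attr.ib(type=str)
    secret = attr.ib(type=Optional[str])

# New style: auto_attribs=True turns plain annotations into fields.
@attr.s(frozen=True, slots=True, auto_attribs=True)
class NewStyle:
    name: str
    secret: Optional[str]

# Both declare identical fields; only the declaration syntax differs.
old = OldStyle(name="oidc", secret=None)
new = NewStyle(name="oidc", secret=None)
```

The generated `__init__`, equality, and immutability behaviour is unchanged by the migration; only the declaration syntax differs.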
00 # Copyright 2015, 2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
1011 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1112 # See the License for the specific language governing permissions and
1213 # limitations under the License.
14 import argparse
15 from typing import Optional
1316
1417 from synapse.api.constants import RoomCreationPreset
1518 from synapse.config._base import Config, ConfigError
112115 self.session_lifetime = session_lifetime
113116
114117 # The `refreshable_access_token_lifetime` applies for tokens that can be renewed
115 # using a refresh token, as per MSC2918. If it is `None`, the refresh
116 # token mechanism is disabled.
117 #
118 # Since it is incompatible with the `session_lifetime` mechanism, it is set to
119 # `None` by default if a `session_lifetime` is set.
118 # using a refresh token, as per MSC2918.
119 # If it is `None`, the refresh token mechanism is disabled.
120120 refreshable_access_token_lifetime = config.get(
121121 "refreshable_access_token_lifetime",
122 "5m" if session_lifetime is None else None,
122 "5m",
123123 )
124124 if refreshable_access_token_lifetime is not None:
125125 refreshable_access_token_lifetime = self.parse_duration(
126126 refreshable_access_token_lifetime
127127 )
128 self.refreshable_access_token_lifetime = refreshable_access_token_lifetime
128 self.refreshable_access_token_lifetime: Optional[
129 int
130 ] = refreshable_access_token_lifetime
129131
130132 if (
131 session_lifetime is not None
132 and refreshable_access_token_lifetime is not None
133 self.session_lifetime is not None
134 and "refreshable_access_token_lifetime" in config
133135 ):
134 raise ConfigError(
135 "The refresh token mechanism is incompatible with the "
136 "`session_lifetime` option. Consider disabling the "
137 "`session_lifetime` option or disabling the refresh token "
138 "mechanism by removing the `refreshable_access_token_lifetime` "
139 "option."
136 if self.session_lifetime < self.refreshable_access_token_lifetime:
137 raise ConfigError(
138 "Both `session_lifetime` and `refreshable_access_token_lifetime` "
139 "configuration options have been set, but `refreshable_access_token_lifetime` "
139 "exceeds `session_lifetime`!"
141 )
142
143 # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be
144 # refreshed using a refresh token.
145 # If it is None, then these tokens last for the entire length of the session,
146 # which is infinite by default.
147 # The intention behind this configuration option is to help with requiring
148 # all clients to use refresh tokens, if the homeserver administrator requires.
149 nonrefreshable_access_token_lifetime = config.get(
150 "nonrefreshable_access_token_lifetime",
151 None,
152 )
153 if nonrefreshable_access_token_lifetime is not None:
154 nonrefreshable_access_token_lifetime = self.parse_duration(
155 nonrefreshable_access_token_lifetime
140156 )
157 self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime
158
159 if (
160 self.session_lifetime is not None
161 and self.nonrefreshable_access_token_lifetime is not None
162 ):
163 if self.session_lifetime < self.nonrefreshable_access_token_lifetime:
164 raise ConfigError(
165 "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` "
166 "configuration options have been set, but `nonrefreshable_access_token_lifetime` "
167 " exceeds `session_lifetime`!"
168 )
169
170 refresh_token_lifetime = config.get("refresh_token_lifetime")
171 if refresh_token_lifetime is not None:
172 refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
173 self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime
174
175 if (
176 self.session_lifetime is not None
177 and self.refresh_token_lifetime is not None
178 ):
179 if self.session_lifetime < self.refresh_token_lifetime:
180 raise ConfigError(
181 "Both `session_lifetime` and `refresh_token_lifetime` "
182 "configuration options have been set, but `refresh_token_lifetime` "
183 "exceeds `session_lifetime`!"
184 )
141185
142186 # The fallback template used for authenticating using a registration token
143187 self.registration_token_template = self.read_template("registration_token.html")
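The three checks added above all follow the same shape: when `session_lifetime` is set alongside one of the token lifetimes, the token lifetime must not exceed the session lifetime. A standalone sketch of that validation, with a simplified duration parser standing in for Synapse's `parse_duration` (names here are illustrative):

```python
from typing import Optional


def parse_duration(value) -> int:
    """Toy duration parser: accepts ints (ms) or strings like '5m' / '24h'."""
    if isinstance(value, int):
        return value
    units = {"s": 1_000, "m": 60_000, "h": 3_600_000, "d": 86_400_000}
    return int(value[:-1]) * units[value[-1]]


def check_lifetime(
    session_lifetime: Optional[int], token_lifetime: Optional[int], name: str
) -> None:
    # Only an error if both options are set and the token outlives the session.
    if session_lifetime is not None and token_lifetime is not None:
        if session_lifetime < token_lifetime:
            raise ValueError(f"`{name}` exceeds `session_lifetime`")


session = parse_duration("24h")
refreshable = parse_duration("5m")
check_lifetime(session, refreshable, "refreshable_access_token_lifetime")  # passes
```

The same helper shape covers all three options (`refreshable_access_token_lifetime`, `nonrefreshable_access_token_lifetime`, `refresh_token_lifetime`), which is why the diff repeats the block three times.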
175219 #
176220 #session_lifetime: 24h
177221
222 # Time that an access token remains valid for, if the session is
223 # using refresh tokens.
224 # For more information about refresh tokens, please see the manual.
225 # Note that this only applies to clients which advertise support for
226 # refresh tokens.
227 #
228 # Note also that this is calculated at login time and refresh time:
229 # changes are not applied to existing sessions until they are refreshed.
230 #
231 # By default, this is 5 minutes.
232 #
233 #refreshable_access_token_lifetime: 5m
234
235 # Time that a refresh token remains valid for (provided that it is not
236 # exchanged for another one first).
237 # This option can be used to automatically log out inactive sessions.
238 # Please see the manual for more information.
239 #
240 # Note also that this is calculated at login time and refresh time:
241 # changes are not applied to existing sessions until they are refreshed.
242 #
243 # By default, this is infinite.
244 #
245 #refresh_token_lifetime: 24h
246
247 # Time that an access token remains valid for, if the session is NOT
248 # using refresh tokens.
249 # Please note that not all clients support refresh tokens, so setting
250 # this to a short value may be inconvenient for some users who will
251 # then be logged out frequently.
252 #
253 # Note also that this is calculated at login time: changes are not applied
254 # retrospectively to existing sessions for users that have already logged in.
255 #
256 # By default, this is infinite.
257 #
258 #nonrefreshable_access_token_lifetime: 24h
259
178260 # The user must provide all of the below types of 3PID when registering.
179261 #
180262 #registrations_require_3pid:
368450 )
369451
370452 @staticmethod
371 def add_arguments(parser):
453 def add_arguments(parser: argparse.ArgumentParser) -> None:
372454 reg_group = parser.add_argument_group("registration")
373455 reg_group.add_argument(
374456 "--enable-registration",
377459 help="Enable registration for new users.",
378460 )
379461
380 def read_arguments(self, args):
462 def read_arguments(self, args: argparse.Namespace) -> None:
381463 if args.enable_registration is not None:
382464 self.enable_registration = strtobool(str(args.enable_registration))
1414 import logging
1515 import os
1616 from collections import namedtuple
17 from typing import Dict, List
17 from typing import Dict, List, Tuple
1818 from urllib.request import getproxies_environment # type: ignore
1919
2020 from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
2121 from synapse.python_dependencies import DependencyException, check_requirements
22 from synapse.types import JsonDict
2223 from synapse.util.module_loader import load_module
2324
2425 from ._base import Config, ConfigError
5657 )
5758
5859
59 def parse_thumbnail_requirements(thumbnail_sizes):
60 def parse_thumbnail_requirements(
61 thumbnail_sizes: List[JsonDict],
62 ) -> Dict[str, Tuple[ThumbnailRequirement, ...]]:
6063 """Takes a list of dictionaries with "width", "height", and "method" keys
6164 and creates a map from image media types to the thumbnail size, thumbnailing
6265 method, and thumbnail media type to precalculate
6871 Dictionary mapping from media type string to list of
6972 ThumbnailRequirement tuples.
7073 """
71 requirements: Dict[str, List] = {}
74 requirements: Dict[str, List[ThumbnailRequirement]] = {}
7275 for size in thumbnail_sizes:
7376 width = size["width"]
7477 height = size["height"]
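The newly typed `parse_thumbnail_requirements` groups each configured size under the media types that can serve it, then freezes each list into a tuple. A simplified sketch of that grouping step (the real code also derives the thumbnail output type per input media type; the JPEG-only output here is an assumption for brevity):

```python
from collections import namedtuple
from typing import Dict, List, Tuple

# Mirrors the field shape used in the surrounding code.
ThumbnailRequirement = namedtuple(
    "ThumbnailRequirement", ["width", "height", "method", "media_type"]
)


def parse_thumbnail_requirements(
    thumbnail_sizes: List[dict],
) -> Dict[str, Tuple[ThumbnailRequirement, ...]]:
    # Build a mutable map first, then freeze each list into a tuple.
    requirements: Dict[str, List[ThumbnailRequirement]] = {}
    for size in thumbnail_sizes:
        width, height, method = size["width"], size["height"], size["method"]
        # Assumption: every input type thumbnails to JPEG in this sketch.
        for media_type in ("image/jpeg", "image/png"):
            requirements.setdefault(media_type, []).append(
                ThumbnailRequirement(width, height, method, "image/jpeg")
            )
    return {k: tuple(v) for k, v in requirements.items()}


reqs = parse_thumbnail_requirements([{"width": 32, "height": 32, "method": "crop"}])
```

Returning tuples rather than lists matches the tightened `Dict[str, Tuple[ThumbnailRequirement, ...]]` return annotation in the diff.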
00 # Copyright 2018 New Vector Ltd
1 # Copyright 2019 The Matrix.org Foundation C.I.C.
1 # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
22 #
33 # Licensed under the Apache License, Version 2.0 (the "License");
44 # you may not use this file except in compliance with the License.
1313 # limitations under the License.
1414
1515 import logging
16 from typing import Any, List
16 from typing import Any, List, Set
1717
1818 from synapse.config.sso import SsoAttributeRequirement
1919 from synapse.python_dependencies import DependencyException, check_requirements
20 from synapse.types import JsonDict
2021 from synapse.util.module_loader import load_module, load_python_module
2122
2223 from ._base import Config, ConfigError
3233 )
3334
3435
35 def _dict_merge(merge_dict, into_dict):
36 def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
3637 """Do a deep merge of two dicts
3738
3839 Recursively merges `merge_dict` into `into_dict`:
4243 the value from `merge_dict`.
4344
4445 Args:
45 merge_dict (dict): dict to merge
46 into_dict (dict): target dict
46 merge_dict: dict to merge
47 into_dict: target dict to be modified
4748 """
4849 for k, v in merge_dict.items():
4950 if k not in into_dict:
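`_dict_merge` recursively folds one dict into another in place: nested dicts are merged key by key, while non-dict values from `merge_dict` overwrite. A self-contained sketch of the same behaviour:

```python
def dict_merge(merge_dict: dict, into_dict: dict) -> None:
    """Recursively merge `merge_dict` into `into_dict`, modifying it in place."""
    for k, v in merge_dict.items():
        if k not in into_dict:
            into_dict[k] = v
        elif isinstance(v, dict) and isinstance(into_dict[k], dict):
            # Both sides are dicts: descend and merge their keys.
            dict_merge(v, into_dict[k])
        else:
            # Non-dict (or mismatched) values: merge_dict wins.
            into_dict[k] = v


target = {"sp": {"name": "synapse", "idp": {"a": 1}}}
dict_merge({"sp": {"idp": {"b": 2}}, "debug": True}, target)
# target now holds both idp keys plus the new top-level "debug" key.
```

This is the pattern SAML2Config uses to layer user-supplied `sp_config` options over the generated defaults without discarding untouched nested keys.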
6364 class SAML2Config(Config):
6465 section = "saml2"
6566
66 def read_config(self, config, **kwargs):
67 def read_config(self, config, **kwargs) -> None:
6768 self.saml2_enabled = False
6869
6970 saml2_config = config.get("saml2_config")
182183 )
183184
184185 def _default_saml_config_dict(
185 self, required_attributes: set, optional_attributes: set
186 ):
186 self, required_attributes: Set[str], optional_attributes: Set[str]
187 ) -> JsonDict:
187188 """Generate a configuration dictionary with required and optional attributes that
188189 will be needed to process new user registration
189190
194195 additional information to Synapse user accounts, but are not required
195196
196197 Returns:
197 dict: A SAML configuration dictionary
198 A SAML configuration dictionary
198199 """
199200 import saml2
200201
221222 },
222223 }
223224
224 def generate_config_section(self, config_dir_path, server_name, **kwargs):
225 def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
225226 return """\
226227 ## Single sign-on integration ##
227228
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313
14 import argparse
1415 import itertools
1516 import logging
1617 import os.path
2627 from twisted.conch.ssh.keys import Key
2728
2829 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
30 from synapse.types import JsonDict
2931 from synapse.util.module_loader import load_module
3032 from synapse.util.stringutils import parse_and_validate_server_name
3133
12221224 % locals()
12231225 )
12241226
1225 def read_arguments(self, args):
1227 def read_arguments(self, args: argparse.Namespace) -> None:
12261228 if args.manhole is not None:
12271229 self.manhole = args.manhole
12281230 if args.daemonize is not None:
12311233 self.print_pidfile = args.print_pidfile
12321234
12331235 @staticmethod
1234 def add_arguments(parser):
1236 def add_arguments(parser: argparse.ArgumentParser) -> None:
12351237 server_group = parser.add_argument_group("server")
12361238 server_group.add_argument(
12371239 "-D",
12731275 )
12741276
12751277
1276 def is_threepid_reserved(reserved_threepids, threepid):
1278 def is_threepid_reserved(
1279 reserved_threepids: List[JsonDict], threepid: JsonDict
1280 ) -> bool:
12771281 """Check the threepid against the reserved threepid config
12781282 Args:
1279 reserved_threepids([dict]) - list of reserved threepids
1280 threepid(dict) - The threepid to test for
1283 reserved_threepids: List of reserved threepids
1284 threepid: The threepid to test for
12811285
12821286 Returns:
1283 boolean Is the threepid undertest reserved_user
1287 Whether the threepid under test is a reserved user
12841288 """
12851289
12861290 for tp in reserved_threepids:
12891293 return False
12901294
12911295
1292 def read_gc_thresholds(thresholds):
1296 def read_gc_thresholds(
1297 thresholds: Optional[List[Any]],
1298 ) -> Optional[Tuple[int, int, int]]:
12931299 """Reads the three integer thresholds for garbage collection. Ensures that
12941300 the thresholds are integers if thresholds are supplied.
12951301 """
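The new `Optional[List[Any]] -> Optional[Tuple[int, int, int]]` signature documents the contract: absent config passes through as `None`, otherwise the three thresholds are coerced to ints. A sketch of that behaviour, assuming a generic `ValueError` where Synapse raises its own config error:

```python
from typing import Any, List, Optional, Tuple


def read_gc_thresholds(
    thresholds: Optional[List[Any]],
) -> Optional[Tuple[int, int, int]]:
    # Absent config means "leave the interpreter's gc defaults alone".
    if thresholds is None:
        return None
    try:
        assert len(thresholds) == 3
        return int(thresholds[0]), int(thresholds[1]), int(thresholds[2])
    except (AssertionError, TypeError, ValueError):
        # Stand-in for Synapse's ConfigError.
        raise ValueError("Value of `gc_threshold` must be a list of three integers")
```

The resulting tuple matches the argument shape expected by Python's `gc.set_threshold`.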
0 # Copyright 2020 The Matrix.org Foundation C.I.C.
0 # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
11 #
22 # Licensed under the Apache License, Version 2.0 (the "License");
33 # you may not use this file except in compliance with the License.
2828 ---------------------------------------------------------------------------------------"""
2929
3030
31 @attr.s(frozen=True)
31 @attr.s(frozen=True, auto_attribs=True)
3232 class SsoAttributeRequirement:
3333 """Object describing a single requirement for SSO attributes."""
3434
35 attribute = attr.ib(type=str)
35 attribute: str
3636 # If a value is not given, then the attribute must simply exist.
37 value = attr.ib(type=Optional[str])
37 value: Optional[str]
3838
3939 JSON_SCHEMA = {
4040 "type": "object",
4848
4949 section = "sso"
5050
51 def read_config(self, config, **kwargs):
51 def read_config(self, config, **kwargs) -> None:
5252 sso_config: Dict[str, Any] = config.get("sso") or {}
5353
5454 # The sso-specific template_dir
105105 )
106106 self.sso_client_whitelist.append(login_fallback_url)
107107
108 def generate_config_section(self, **kwargs):
108 def generate_config_section(self, **kwargs) -> str:
109109 return """\
110110 # Additional settings to use with single-sign on systems such as OpenID Connect,
111111 # SAML2 and CAS.
00 # Copyright 2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
1112 # See the License for the specific language governing permissions and
1213 # limitations under the License.
1314
15 import argparse
1416 from typing import List, Union
1517
1618 import attr
342344 #worker_replication_secret: ""
343345 """
344346
345 def read_arguments(self, args):
347 def read_arguments(self, args: argparse.Namespace) -> None:
346348 # We support a bunch of command line arguments that override options in
347349 # the config. A lot of these options have a worker_* prefix when running
348350 # on workers so we also have to override them when command line options
666666 perspective_name,
667667 )
668668
669 request: JsonDict = {}
670 for queue_value in keys_to_fetch:
671 # there may be multiple requests for each server, so we have to merge
672 # them intelligently.
673 request_for_server = {
674 key_id: {
675 "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
676 }
677 for key_id in queue_value.key_ids
678 }
679 request.setdefault(queue_value.server_name, {}).update(request_for_server)
680
681 logger.debug("Request to notary server %s: %s", perspective_name, request)
682
669683 try:
670684 query_response = await self.client.post_json(
671685 destination=perspective_name,
672686 path="/_matrix/key/v2/query",
673 data={
674 "server_keys": {
675 queue_value.server_name: {
676 key_id: {
677 "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
678 }
679 for key_id in queue_value.key_ids
680 }
681 for queue_value in keys_to_fetch
682 }
683 },
687 data={"server_keys": request},
684688 )
685689 except (NotRetryingDestination, RequestSendFailed) as e:
686690 # these both have str() representations which we can't really improve upon
687691 raise KeyLookupError(str(e))
688692 except HttpResponseException as e:
689693 raise KeyLookupError("Remote server returned an error: %s" % (e,))
694
695 logger.debug(
696 "Response from notary server %s: %s", perspective_name, query_response
697 )
690698
691699 keys: Dict[str, Dict[str, FetchKeyResult]] = {}
692700 added_keys: List[Tuple[str, str, FetchKeyResult]] = []
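The refactor above builds the notary request up front so it can be logged before sending, merging multiple queued requests for the same server with `setdefault(...).update(...)`. A small sketch of that merge, using a simplified stand-in for the real queue entry:

```python
from typing import Dict, NamedTuple, Tuple


class QueueValue(NamedTuple):
    """Simplified stand-in for the real key-fetch queue entry."""
    server_name: str
    key_ids: Tuple[str, ...]
    minimum_valid_until_ts: int


def build_notary_request(keys_to_fetch) -> Dict[str, dict]:
    request: Dict[str, dict] = {}
    for qv in keys_to_fetch:
        # There may be multiple entries per server; merge them into one map.
        per_server = {
            key_id: {"minimum_valid_until_ts": qv.minimum_valid_until_ts}
            for key_id in qv.key_ids
        }
        request.setdefault(qv.server_name, {}).update(per_server)
    return request


req = build_notary_request(
    [
        QueueValue("example.org", ("ed25519:a",), 1000),
        QueueValue("example.org", ("ed25519:b",), 2000),
    ]
)
```

The merged dict is what the diff sends as the `server_keys` body of the `/_matrix/key/v2/query` POST.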
321321 attributes by loading from the database.
322322 """
323323 if self.state_group is None:
324 # No state group means the event is an outlier. Usually the state_ids dicts are also
325 # pre-set to empty dicts, but they get reset when the context is serialized, so set
326 # them to empty dicts again here.
327 self._current_state_ids = {}
328 self._prev_state_ids = {}
324329 return
325330
326331 current_state_ids = await self._storage.state.get_state_ids_for_group(
305305 def serialize_event(
306306 e: Union[JsonDict, EventBase],
307307 time_now_ms: int,
308 *,
308309 as_client_event: bool = True,
309310 event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
310311 token_id: Optional[str] = None,
392393 self,
393394 event: Union[JsonDict, EventBase],
394395 time_now: int,
395 bundle_relations: bool = True,
396 *,
397 bundle_aggregations: bool = True,
396398 **kwargs: Any,
397399 ) -> JsonDict:
398400 """Serializes a single event.
400402 Args:
401403 event: The event being serialized.
402404 time_now: The current time in milliseconds
403 bundle_relations: Whether to include the bundled relations for this
404 event.
405 bundle_aggregations: Whether to include the bundled aggregations for this
406 event. Only applies to non-state events. (State events never include
407 bundled aggregations.)
405408 **kwargs: Arguments to pass to `serialize_event`
406409
407410 Returns:
413416
414417 serialized_event = serialize_event(event, time_now, **kwargs)
415418
416 # If MSC1849 is enabled then we need to look if there are any relations
417 # we need to bundle in with the event.
418 # Do not bundle relations if the event has been redacted
419 if not event.internal_metadata.is_redacted() and (
420 self._msc1849_enabled and bundle_relations
419 # Check if there are any bundled aggregations to include with the event.
420 #
421 # Do not bundle aggregations if any of the following are true:
422 #
423 # * Support is disabled via the configuration or the caller.
424 # * The event is a state event.
425 # * The event has been redacted.
426 if (
427 self._msc1849_enabled
428 and bundle_aggregations
429 and not event.is_state()
430 and not event.internal_metadata.is_redacted()
421431 ):
422 await self._injected_bundled_relations(event, time_now, serialized_event)
432 await self._injected_bundled_aggregations(event, time_now, serialized_event)
423433
424434 return serialized_event
425435
426 async def _injected_bundled_relations(
436 async def _injected_bundled_aggregations(
427437 self, event: EventBase, time_now: int, serialized_event: JsonDict
428438 ) -> None:
429 """Potentially injects bundled relations into the unsigned portion of the serialized event.
439 """Potentially injects bundled aggregations into the unsigned portion of the serialized event.
430440
431441 Args:
432442 event: The event being serialized.
434444 serialized_event: The serialized event which may be modified.
435445
436446 """
447 # Do not bundle aggregations for an event which represents an edit or an
448 # annotation. It does not make sense for them to have related events.
449 relates_to = event.content.get("m.relates_to")
450 if isinstance(relates_to, (dict, frozendict)):
451 relation_type = relates_to.get("rel_type")
452 if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
453 return
454
437455 event_id = event.event_id
438456
439 # The bundled relations to include.
440 relations = {}
457 # The bundled aggregations to include.
458 aggregations = {}
441459
442460 annotations = await self.store.get_aggregation_groups_for_event(event_id)
443461 if annotations.chunk:
444 relations[RelationTypes.ANNOTATION] = annotations.to_dict()
462 aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
445463
446464 references = await self.store.get_relations_for_event(
447465 event_id, RelationTypes.REFERENCE, direction="f"
448466 )
449467 if references.chunk:
450 relations[RelationTypes.REFERENCE] = references.to_dict()
468 aggregations[RelationTypes.REFERENCE] = references.to_dict()
451469
452470 edit = None
453471 if event.type == EventTypes.Message:
473491 else:
474492 serialized_event["content"].pop("m.relates_to", None)
475493
476 relations[RelationTypes.REPLACE] = {
494 aggregations[RelationTypes.REPLACE] = {
477495 "event_id": edit.event_id,
478496 "origin_server_ts": edit.origin_server_ts,
479497 "sender": edit.sender,
486504 latest_thread_event,
487505 ) = await self.store.get_thread_summary(event_id)
488506 if latest_thread_event:
489 relations[RelationTypes.THREAD] = {
490 # Don't bundle relations as this could recurse forever.
507 aggregations[RelationTypes.THREAD] = {
508 # Don't bundle aggregations as this could recurse forever.
491509 "latest_event": await self.serialize_event(
492 latest_thread_event, time_now, bundle_relations=False
510 latest_thread_event, time_now, bundle_aggregations=False
493511 ),
494512 "count": thread_count,
495513 }
496514
497 # If any bundled relations were found, include them.
498 if relations:
499 serialized_event["unsigned"].setdefault("m.relations", {}).update(relations)
515 # If any bundled aggregations were found, include them.
516 if aggregations:
517 serialized_event["unsigned"].setdefault("m.relations", {}).update(
518 aggregations
519 )
500520
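The guard added above skips bundling for events that are themselves edits or annotations, by inspecting `m.relates_to` in the event content. A minimal sketch of that check on a plain dict (the real code also accepts frozendicts and combines this with the state-event and redaction checks):

```python
# Relation type constants matching those used in the surrounding code.
ANNOTATION = "m.annotation"
REPLACE = "m.replace"


def should_bundle_aggregations(content: dict) -> bool:
    """Return False for events that are themselves edits or annotations."""
    relates_to = content.get("m.relates_to")
    if isinstance(relates_to, dict):
        if relates_to.get("rel_type") in (ANNOTATION, REPLACE):
            return False
    return True
```

Edits and annotations are excluded because they are aggregated onto their target event; bundling aggregations onto them in turn would serve no purpose.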
501521 async def serialize_events(
502522 self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
127127 reset_expiry_on_get=False,
128128 )
129129
130 def _clear_tried_cache(self):
130 def _clear_tried_cache(self) -> None:
131131 """Clear pdu_destination_tried cache"""
132132 now = self._clock.time_msec()
133133
799799 no servers successfully handle the request.
800800 """
801801
802 async def send_request(destination) -> SendJoinResult:
802 async def send_request(destination: str) -> SendJoinResult:
803803 response = await self._do_send_join(room_version, destination, pdu)
804804
805805 # If an event was returned (and expected to be returned):
13941394 async def send_request(
13951395 destination: str,
13961396 ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
1397 res = await self.transport_layer.get_room_hierarchy(
1398 destination=destination,
1399 room_id=room_id,
1400 suggested_only=suggested_only,
1401 )
1397 try:
1398 res = await self.transport_layer.get_room_hierarchy(
1399 destination=destination,
1400 room_id=room_id,
1401 suggested_only=suggested_only,
1402 )
1403 except HttpResponseException as e:
1404 # If an error is received that is due to an unrecognised endpoint,
1405 # fall back to the unstable endpoint. Otherwise consider it a
1406 # legitimate error and raise.
1407 if not self._is_unknown_endpoint(e):
1408 raise
1409
1410 logger.debug(
1411 "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API"
1412 )
1413
1414 res = await self.transport_layer.get_room_hierarchy_unstable(
1415 destination=destination,
1416 room_id=room_id,
1417 suggested_only=suggested_only,
1418 )
14021419
14031420 room = res.get("room")
14041421 if not isinstance(room, dict):
14481465 if e.code != 502:
14491466 raise
14501467
1468 logger.debug(
1469 "Couldn't fetch room hierarchy, falling back to the spaces API"
1470 )
1471
14511472 # Fallback to the old federation API and translate the results if
14521473 # no servers implement the new API.
14531474 #
14951516 self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
14961517 return result
14971518
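The hierarchy change above wraps the stable endpoint call in a try/except so that an "unrecognised endpoint" error triggers a retry against the unstable prefix, while any other error propagates. The general pattern, sketched with hypothetical fetch functions in place of the transport-layer calls:

```python
class UnknownEndpointError(Exception):
    """Stand-in for an HTTP error meaning the server doesn't know the endpoint."""


def fetch_with_fallback(primary, fallback):
    """Call `primary`; if the server doesn't recognise the endpoint, try `fallback`."""
    try:
        return primary()
    except UnknownEndpointError:
        # Older servers only implement the unstable (MSC-prefixed) endpoint.
        return fallback()


def stable():
    raise UnknownEndpointError("M_UNRECOGNIZED")


def unstable():
    return {"room": {"room_id": "!abc:example.org"}}


res = fetch_with_fallback(stable, unstable)
```

In the diff, `_is_unknown_endpoint` plays the role of the exception filter: only errors it classifies as "endpoint not implemented" fall through to the unstable API.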
1519 async def timestamp_to_event(
1520 self, destination: str, room_id: str, timestamp: int, direction: str
1521 ) -> "TimestampToEventResponse":
1522 """
1523 Calls a remote federating server at `destination` asking for their
1524 closest event to the given timestamp in the given direction. Also
1525 validates the response to always return the expected keys or raises an
1526 error.
1527
1528 Args:
1529 destination: Domain name of the remote homeserver
1530 room_id: Room to fetch the event from
1531 timestamp: The point in time (inclusive) we should navigate from in
1532 the given direction to find the closest event.
1533 direction: ["f"|"b"] to indicate whether we should navigate forward
1534 or backward from the given timestamp to find the closest event.
1535
1536 Returns:
1537 A parsed TimestampToEventResponse including the closest event_id
1538 and origin_server_ts
1539
1540 Raises:
1541 Various exceptions when the request fails
1542 InvalidResponseError when the response does not have the correct
1543 keys or wrong types
1544 """
1545 remote_response = await self.transport_layer.timestamp_to_event(
1546 destination, room_id, timestamp, direction
1547 )
1548
1549 if not isinstance(remote_response, dict):
1550 raise InvalidResponseError(
1551 "Response must be a JSON dictionary but received %r" % remote_response
1552 )
1553
1554 try:
1555 return TimestampToEventResponse.from_json_dict(remote_response)
1556 except ValueError as e:
1557 raise InvalidResponseError(str(e))
1558
1559
1560 @attr.s(frozen=True, slots=True, auto_attribs=True)
1561 class TimestampToEventResponse:
1562 """Typed response dictionary for the federation /timestamp_to_event endpoint"""
1563
1564 event_id: str
1565 origin_server_ts: int
1566
1567 # the raw data, including the above keys
1568 data: JsonDict
1569
1570 @classmethod
1571 def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse":
1572 """Parsed response from the federation /timestamp_to_event endpoint
1573
1574 Args:
1575 d: JSON object response to be parsed
1576
1577 Raises:
1578 ValueError if d does not have the correct keys or they are the wrong types
1579 """
1580
1581 event_id = d.get("event_id")
1582 if not isinstance(event_id, str):
1583 raise ValueError(
1584 "Invalid response: 'event_id' must be a str but received %r" % event_id
1585 )
1586
1587 origin_server_ts = d.get("origin_server_ts")
1588 if not isinstance(origin_server_ts, int):
1589 raise ValueError(
1590 "Invalid response: 'origin_server_ts' must be an int but received %r"
1591 % origin_server_ts
1592 )
1593
1594 return cls(event_id, origin_server_ts, d)
1595
14981596
14991597 @attr.s(frozen=True, slots=True, auto_attribs=True)
15001598 class FederationSpaceSummaryEventResult:
00 # Copyright 2015, 2016 OpenMarket Ltd
11 # Copyright 2018 New Vector Ltd
2 # Copyright 2019 Matrix.org Federation C.I.C
2 # Copyright 2019-2021 Matrix.org Federation C.I.C
33 #
44 # Licensed under the Apache License, Version 2.0 (the "License");
55 # you may not use this file except in compliance with the License.
109109 super().__init__(hs)
110110
111111 self.handler = hs.get_federation_handler()
112 self.storage = hs.get_storage()
112113 self._federation_event_handler = hs.get_federation_event_handler()
113114 self.state = hs.get_state_handler()
114115 self._event_auth_handler = hs.get_event_auth_handler()
198199 res = self._transaction_dict_from_pdus(pdus)
199200
200201 return 200, res
202
203 async def on_timestamp_to_event_request(
204 self, origin: str, room_id: str, timestamp: int, direction: str
205 ) -> Tuple[int, Dict[str, Any]]:
206 """When we receive a federated `/timestamp_to_event` request,
207 handle all of the logic for validating and fetching the event.
208
209 Args:
210 origin: The server we received the event from
211 room_id: Room to fetch the event from
212 timestamp: The point in time (inclusive) we should navigate from in
213 the given direction to find the closest event.
214 direction: ["f"|"b"] to indicate whether we should navigate forward
215 or backward from the given timestamp to find the closest event.
216
217 Returns:
218 Tuple indicating the response status code and dictionary response
219 body including `event_id`.
220 """
221 with (await self._server_linearizer.queue((origin, room_id))):
222 origin_host, _ = parse_server_name(origin)
223 await self.check_server_matches_acl(origin_host, room_id)
224
225 # We only try to fetch data from the local database
226 event_id = await self.store.get_event_id_for_timestamp(
227 room_id, timestamp, direction
228 )
229 if event_id:
230 event = await self.store.get_event(
231 event_id, allow_none=False, allow_rejected=False
232 )
233
234 return 200, {
235 "event_id": event_id,
236 "origin_server_ts": event.origin_server_ts,
237 }
238
239 raise SynapseError(
240 404,
241 "Unable to find event from %s in direction %s" % (timestamp, direction),
242 errcode=Codes.NOT_FOUND,
243 )
201244
202245 async def on_incoming_transaction(
203246 self,
406449 # require callouts to other servers to fetch missing events), but
407450 # impose a limit to avoid going too crazy with ram/cpu.
408451
409 async def process_pdus_for_room(room_id: str):
452 async def process_pdus_for_room(room_id: str) -> None:
410453 with nested_logging_context(room_id):
411454 logger.debug("Processing PDUs for %s", room_id)
412455
503546
504547 async def on_state_ids_request(
505548 self, origin: str, room_id: str, event_id: str
506 ) -> Tuple[int, Dict[str, Any]]:
549 ) -> Tuple[int, JsonDict]:
507550 if not event_id:
508551 raise NotImplementedError("Specify an event")
509552
523566
524567 return 200, resp
525568
526 async def _on_state_ids_request_compute(self, room_id, event_id):
569 async def _on_state_ids_request_compute(
570 self, room_id: str, event_id: str
571 ) -> JsonDict:
527572 state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
528573 auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
529574 return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
612657 state = await self.store.get_events(state_ids)
613658
614659 time_now = self._clock.time_msec()
660 event_json = event.get_pdu_json()
615661 return {
616 "org.matrix.msc3083.v2.event": event.get_pdu_json(),
662 # TODO Remove the unstable prefix when servers have updated.
663 "org.matrix.msc3083.v2.event": event_json,
664 "event": event_json,
617665 "state": [p.get_pdu_json(time_now) for p in state.values()],
618666 "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
619667 }
00 # Copyright 2014-2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
2223
2324 from synapse.federation.units import Transaction
2425 from synapse.logging.utils import log_function
26 from synapse.storage.databases.main import DataStore
2527 from synapse.types import JsonDict
2628
2729 logger = logging.getLogger(__name__)
3032 class TransactionActions:
3133 """Defines persistence actions that relate to handling Transactions."""
3234
33 def __init__(self, datastore):
35 def __init__(self, datastore: DataStore):
3436 self.store = datastore
3537
3638 @log_function
00 # Copyright 2014-2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
349350 TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
350351
351352 @staticmethod
352 def from_data(data):
353 def from_data(data: JsonDict) -> "BaseFederationRow":
353354 """Parse the data from the federation stream into a row.
354355
355356 Args:
358359 """
359360 raise NotImplementedError()
360361
361 def to_data(self):
362 def to_data(self) -> JsonDict:
362363 """Serialize this row to be sent over the federation stream.
363364
364365 Returns:
367368 """
368369 raise NotImplementedError()
369370
370 def add_to_buffer(self, buff):
371 def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
371372 """Add this row to the appropriate field in the buffer ready for this
372373 to be sent over federation.
373374
390391 TypeId = "pd"
391392
392393 @staticmethod
393 def from_data(data):
394 def from_data(data: JsonDict) -> "PresenceDestinationsRow":
394395 return PresenceDestinationsRow(
395396 state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
396397 )
397398
398 def to_data(self):
399 def to_data(self) -> JsonDict:
399400 return {"state": self.state.as_dict(), "dests": self.destinations}
400401
401 def add_to_buffer(self, buff):
402 def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
402403 buff.presence_destinations.append((self.state, self.destinations))
403404
404405
416417 TypeId = "k"
417418
418419 @staticmethod
419 def from_data(data):
420 def from_data(data: JsonDict) -> "KeyedEduRow":
420421 return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
421422
422 def to_data(self):
423 def to_data(self) -> JsonDict:
423424 return {"key": self.key, "edu": self.edu.get_internal_dict()}
424425
425 def add_to_buffer(self, buff):
426 def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
426427 buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
427428
428429
432433 TypeId = "e"
433434
434435 @staticmethod
435 def from_data(data):
436 def from_data(data: JsonDict) -> "EduRow":
436437 return EduRow(Edu(**data))
437438
438 def to_data(self):
439 def to_data(self) -> JsonDict:
439440 return self.edu.get_internal_dict()
440441
441 def add_to_buffer(self, buff):
442 def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
442443 buff.edus.setdefault(self.edu.destination, []).append(self.edu)
443444
444445
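The `from_data`/`to_data` pairs annotated above implement a simple JSON round-trip for federation stream rows. A minimal standalone sketch of the same pattern, using a hypothetical `EduRowLike` class in place of Synapse's attrs-based row types:

```python
class EduRowLike:
    """Stand-in for a federation stream row with a JSON round-trip."""

    def __init__(self, destination: str, edu_type: str, content: dict):
        self.destination = destination
        self.edu_type = edu_type
        self.content = content

    @staticmethod
    def from_data(data: dict) -> "EduRowLike":
        # Parse the dict received over the federation stream into a row.
        return EduRowLike(**data)

    def to_data(self) -> dict:
        # Serialize this row to be sent over the federation stream.
        return {
            "destination": self.destination,
            "edu_type": self.edu_type,
            "content": self.content,
        }
```

The type annotations added in the diff make exactly this contract explicit: `from_data` consumes a `JsonDict` and `to_data` produces one.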
00 # Copyright 2014-2016 OpenMarket Ltd
11 # Copyright 2019 New Vector Ltd
2 # Copyright 2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415 import datetime
1516 import logging
16 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
17 from types import TracebackType
18 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
1719
1820 import attr
1921 from prometheus_client import Counter
212214 self._pending_edus_keyed[(edu.edu_type, key)] = edu
213215 self.attempt_new_transaction()
214216
215 def send_edu(self, edu) -> None:
217 def send_edu(self, edu: Edu) -> None:
216218 self._pending_edus.append(edu)
217219 self.attempt_new_transaction()
218220
700702
701703 return self._pdus, pending_edus
702704
703 async def __aexit__(self, exc_type, exc, tb):
705 async def __aexit__(
706 self,
707 exc_type: Optional[Type[BaseException]],
708 exc: Optional[BaseException],
709 tb: Optional[TracebackType],
710 ) -> None:
704711 if exc_type is not None:
705712 # Failed to send transaction, so we bail out.
706713 return
2020 Callable,
2121 Collection,
2222 Dict,
23 Generator,
2324 Iterable,
2425 List,
2526 Mapping,
148149 )
149150
150151 @log_function
152 async def timestamp_to_event(
153 self, destination: str, room_id: str, timestamp: int, direction: str
154 ) -> Union[JsonDict, List]:
155 """
156 Calls a remote federating server at `destination` asking for its
157 closest event to the given timestamp in the given direction.
158
159 Args:
160 destination: Domain name of the remote homeserver
161 room_id: Room to fetch the event from
162 timestamp: The point in time (inclusive) we should navigate from in
163 the given direction to find the closest event.
164 direction: ["f"|"b"] to indicate whether we should navigate forward
165 or backward from the given timestamp to find the closest event.
166
167 Returns:
168 Response dict received from the remote homeserver.
169
170 Raises:
171 Various exceptions when the request fails
172 """
173 path = _create_path(
174 FEDERATION_UNSTABLE_PREFIX,
175 "/org.matrix.msc3030/timestamp_to_event/%s",
176 room_id,
177 )
178
179 args = {"ts": [str(timestamp)], "dir": [direction]}
180
181 remote_response = await self.client.get_json(
182 destination, path=path, args=args, try_trailing_slash_on_400=True
183 )
184
185 return remote_response
186
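The new `timestamp_to_event` client method builds an unstable-prefixed path and passes `ts`/`dir` as string-list query arguments. A minimal sketch of that request shape, using a hypothetical helper (the real code URL-escapes `room_id` via `_create_path`; this sketch concatenates it directly for brevity):

```python
def build_timestamp_to_event_request(room_id: str, timestamp: int, direction: str):
    """Return the (path, args) pair for an MSC3030 timestamp lookup."""
    if direction not in ("f", "b"):
        raise ValueError("direction must be 'f' (forwards) or 'b' (backwards)")
    path = (
        "/_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/"
        + room_id
    )
    # Query values are lists of strings, matching the `args` dict above.
    args = {"ts": [str(timestamp)], "dir": [direction]}
    return path, args
```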
187 @log_function
151188 async def send_transaction(
152189 self,
153190 transaction: Transaction,
198235
199236 @log_function
200237 async def make_query(
201 self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
202 ):
238 self,
239 destination: str,
240 query_type: str,
241 args: dict,
242 retry_on_dns_fail: bool,
243 ignore_backoff: bool = False,
244 ) -> JsonDict:
203245 path = _create_v1_path("/query/%s", query_type)
204246
205 content = await self.client.get_json(
247 return await self.client.get_json(
206248 destination=destination,
207249 path=path,
208250 args=args,
210252 timeout=10000,
211253 ignore_backoff=ignore_backoff,
212254 )
213
214 return content
215255
216256 @log_function
217257 async def make_membership_event(
11911231 )
11921232
11931233 async def get_room_hierarchy(
1194 self,
1195 destination: str,
1196 room_id: str,
1197 suggested_only: bool,
1234 self, destination: str, room_id: str, suggested_only: bool
1235 ) -> JsonDict:
1236 """
1237 Args:
1238 destination: The remote server
1239 room_id: The room ID to ask about.
1240 suggested_only: if True, only suggested rooms will be returned
1241 """
1242 path = _create_v1_path("/hierarchy/%s", room_id)
1243
1244 return await self.client.get_json(
1245 destination=destination,
1246 path=path,
1247 args={"suggested_only": "true" if suggested_only else "false"},
1248 )
1249
1250 async def get_room_hierarchy_unstable(
1251 self, destination: str, room_id: str, suggested_only: bool
11981252 ) -> JsonDict:
11991253 """
12001254 Args:
12661320
12671321
12681322 @ijson.coroutine
1269 def _event_parser(event_dict: JsonDict):
1323 def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
12701324 """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
12711325 to add them to a given dictionary.
12721326 """
12771331
12781332
12791333 @ijson.coroutine
1280 def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
1334 def _event_list_parser(
1335 room_version: RoomVersion, events: List[EventBase]
1336 ) -> Generator[None, JsonDict, None]:
12811337 """Helper function for use with `ijson.items_coro` to parse an array of
12821338 events and add them to the given list.
12831339 """
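The `@ijson.coroutine` helpers above are generators that receive parsed items via `send()`. A plain-Python sketch of the key-value variant, without ijson:

```python
# A generator-based analogue of Synapse's `_event_parser` coroutine:
# each value sent in is a (key, value) pair merged into event_dict.
def event_parser(event_dict: dict):
    while True:
        key, value = yield
        event_dict[key] = value

event = {}
parser = event_parser(event)
next(parser)  # prime the generator (ijson.coroutine does this for you)
parser.send(("type", "m.room.message"))
parser.send(("sender", "@alice:example.org"))
```

This is what the new `Generator[None, Tuple[str, Any], None]` annotation expresses: the generator yields nothing and receives key-value tuples.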
13161372 prefix + "auth_chain.item",
13171373 use_float=True,
13181374 )
1319 self._coro_event = ijson.kvitems_coro(
1375 # TODO Remove the unstable prefix when servers have updated.
1376 #
1377 # By re-using the same event dictionary this will cause the parsing of
1378 # org.matrix.msc3083.v2.event and event to stomp over each other.
1379 # Generally this should be fine.
1380 self._coro_unstable_event = ijson.kvitems_coro(
13201381 _event_parser(self._response.event_dict),
13211382 prefix + "org.matrix.msc3083.v2.event",
13221383 use_float=True,
13231384 )
1385 self._coro_event = ijson.kvitems_coro(
1386 _event_parser(self._response.event_dict),
1387 prefix + "event",
1388 use_float=True,
1389 )
13241390
13251391 def write(self, data: bytes) -> int:
13261392 self._coro_state.send(data)
13271393 self._coro_auth.send(data)
1394 self._coro_unstable_event.send(data)
13281395 self._coro_event.send(data)
13291396
13301397 return len(data)
2121 Authenticator,
2222 BaseFederationServlet,
2323 )
24 from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES
24 from synapse.federation.transport.server.federation import (
25 FEDERATION_SERVLET_CLASSES,
26 FederationTimestampLookupServlet,
27 )
2528 from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
2629 from synapse.federation.transport.server.groups_server import (
2730 GROUP_SERVER_SERVLET_CLASSES,
298301 authenticator: Authenticator,
299302 ratelimiter: FederationRateLimiter,
300303 servlet_groups: Optional[Iterable[str]] = None,
301 ):
304 ) -> None:
302305 """Initialize and register servlet classes.
303306
304307 Will by default register all servlets. For custom behaviour, pass in
323326 )
324327
325328 for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]:
329 # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled
330 if (
331 servletclass == FederationTimestampLookupServlet
332 and not hs.config.experimental.msc3030_enabled
333 ):
334 continue
335
326336 servletclass(
327337 hs=hs,
328338 authenticator=authenticator,
1414 import functools
1515 import logging
1616 import re
17 from typing import Any, Awaitable, Callable, Optional, Tuple, cast
1718
1819 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
1920 from synapse.api.urls import FEDERATION_V1_PREFIX
21 from synapse.http.server import HttpServer, ServletCallback
2022 from synapse.http.servlet import parse_json_object_from_request
23 from synapse.http.site import SynapseRequest
2124 from synapse.logging import opentracing
2225 from synapse.logging.context import run_in_background
2326 from synapse.logging.opentracing import (
2831 whitelisted_homeserver,
2932 )
3033 from synapse.server import HomeServer
34 from synapse.types import JsonDict
3135 from synapse.util.ratelimitutils import FederationRateLimiter
3236 from synapse.util.stringutils import parse_and_validate_server_name
3337
5862 self.replication_client = hs.get_tcp_replication()
5963
6064 # A method just so we can pass 'self' as the authenticator to the Servlets
61 async def authenticate_request(self, request, content):
65 async def authenticate_request(
66 self, request: SynapseRequest, content: Optional[JsonDict]
67 ) -> str:
6268 now = self._clock.time_msec()
63 json_request = {
69 json_request: JsonDict = {
6470 "method": request.method.decode("ascii"),
6571 "uri": request.uri.decode("ascii"),
6672 "destination": self.server_name,
113119
114120 return origin
115121
116 async def _reset_retry_timings(self, origin):
122 async def _reset_retry_timings(self, origin: str) -> None:
117123 try:
118124 logger.info("Marking origin %r as up", origin)
119125 await self.store.set_destination_retry_timings(origin, None, 0, 0)
132138 logger.exception("Error resetting retry timings on %s", origin)
133139
134140
135 def _parse_auth_header(header_bytes):
141 def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
136142 """Parse an X-Matrix auth header
137143
138144 Args:
139 header_bytes (bytes): header value
145 header_bytes: header value
140146
141147 Returns:
142 Tuple[str, str, str]: origin, key id, signature.
148 origin, key id, signature.
143149
144150 Raises:
145151 AuthenticationError if the header could not be parsed
147153 try:
148154 header_str = header_bytes.decode("utf-8")
149155 params = header_str.split(" ")[1].split(",")
150 param_dict = dict(kv.split("=") for kv in params)
151
152 def strip_quotes(value):
156 param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
157
158 def strip_quotes(value: str) -> str:
153159 if value.startswith('"'):
154160 return value[1:-1]
155161 else:
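The corrected `_parse_auth_header` now splits each `key=value` pair with `maxsplit=1`, so `=` padding inside base64 signature values survives. A standalone sketch of the same parsing, with a hypothetical `parse_x_matrix` helper:

```python
def parse_x_matrix(header: str):
    """Parse an X-Matrix auth header into (origin, key id, signature)."""
    # e.g. 'X-Matrix origin=example.org,key="ed25519:0",sig="abc=="'
    params = header.split(" ")[1].split(",")
    # maxsplit=1 keeps '=' characters inside the value (e.g. base64 padding)
    param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}

    def strip_quotes(value: str) -> str:
        return value[1:-1] if value.startswith('"') else value

    origin = strip_quotes(param_dict["origin"])
    key = strip_quotes(param_dict["key"])
    sig = strip_quotes(param_dict["sig"])
    return origin, key, sig
```

With the old unbounded `split("=")`, a signature such as `"abc=="` would have raised a too-many-values error when building the dict; `maxsplit=1` avoids that.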
232238 self.ratelimiter = ratelimiter
233239 self.server_name = server_name
234240
235 def _wrap(self, func):
241 def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
236242 authenticator = self.authenticator
237243 ratelimiter = self.ratelimiter
238244
239245 @functools.wraps(func)
240 async def new_func(request, *args, **kwargs):
246 async def new_func(
247 request: SynapseRequest, *args: Any, **kwargs: str
248 ) -> Optional[Tuple[int, Any]]:
241249 """A callback which can be passed to HttpServer.RegisterPaths
242250
243251 Args:
244 request (twisted.web.http.Request):
252 request:
245253 *args: unused?
246 **kwargs (dict[unicode, unicode]): the dict mapping keys to path
247 components as specified in the path match regexp.
254 **kwargs: the dict mapping keys to path components as specified
255 in the path match regexp.
248256
249257 Returns:
250 Tuple[int, object]|None: (response code, response object) as returned by
251 the callback method. None if the request has already been handled.
258 (response code, response object) as returned by the callback method.
259 None if the request has already been handled.
252260 """
253261 content = None
254262 if request.method in [b"PUT", b"POST"]:
256264 content = parse_json_object_from_request(request)
257265
258266 try:
259 origin = await authenticator.authenticate_request(request, content)
267 origin: Optional[str] = await authenticator.authenticate_request(
268 request, content
269 )
260270 except NoAuthenticationError:
261271 origin = None
262272 if self.REQUIRE_AUTH:
300310 "client disconnected before we started processing "
301311 "request"
302312 )
303 return -1, None
313 return None
304314 response = await func(
305315 origin, content, request.args, *args, **kwargs
306316 )
311321
312322 return response
313323
314 return new_func
315
316 def register(self, server):
324 return cast(ServletCallback, new_func)
325
326 def register(self, server: HttpServer) -> None:
317327 pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
318328
319329 for method in ("GET", "PUT", "POST"):
171171 return 400, {"error": "Did not include limit param"}
172172
173173 return await self.handler.on_backfill_request(origin, room_id, versions, limit)
174
175
176 class FederationTimestampLookupServlet(BaseFederationServerServlet):
177 """
178 API endpoint to fetch the `event_id` of the closest event to the given
179 timestamp (`ts` query parameter) in the given direction (`dir` query
180 parameter).
181
182 Useful for other homeservers when they're unable to find an event locally.
183
 184 `ts` is a timestamp in milliseconds from which we will find the closest event in
 185 the given direction.
186
187 `dir` can be `f` or `b` to indicate forwards and backwards in time from the
188 given timestamp.
189
190 GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<roomID>?ts=<timestamp>&dir=<direction>
191 {
192 "event_id": ...
193 }
194 """
195
196 PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?"
197 PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030"
198
199 async def on_GET(
200 self,
201 origin: str,
202 content: Literal[None],
203 query: Dict[bytes, List[bytes]],
204 room_id: str,
205 ) -> Tuple[int, JsonDict]:
206 timestamp = parse_integer_from_args(query, "ts", required=True)
207 direction = parse_string_from_args(
208 query, "dir", default="f", allowed_values=["f", "b"], required=True
209 )
210
211 return await self.handler.on_timestamp_to_event_request(
212 origin, room_id, timestamp, direction
213 )
174214
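The servlet above reads `ts` (required integer) and `dir` (`f` or `b`, default `f`) from the query string. A rough standalone sketch of that validation, assuming Synapse's convention of byte-keyed dicts mapping to lists of byte values:

```python
def parse_timestamp_query(query: dict):
    """Validate the `ts` and `dir` query parameters (hypothetical helper)."""
    ts_values = query.get(b"ts")
    if not ts_values:
        raise ValueError("Missing required query parameter: ts")
    timestamp = int(ts_values[0])
    direction = query.get(b"dir", [b"f"])[0].decode("ascii")
    if direction not in ("f", "b"):
        raise ValueError("dir must be 'f' or 'b'")
    return timestamp, direction
```

In the real servlet this is what `parse_integer_from_args` and `parse_string_from_args(..., allowed_values=["f", "b"])` take care of.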
175215
176216 class FederationQueryServlet(BaseFederationServerServlet):
610650
611651
612652 class FederationRoomHierarchyServlet(BaseFederationServlet):
613 PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
614653 PATH = "/hierarchy/(?P<room_id>[^/]*)"
615654
616655 def __init__(
636675 )
637676
638677
678 class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet):
679 PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
680
681
639682 class RoomComplexityServlet(BaseFederationServlet):
640683 """
641684 Indicates to other servers how complex (and therefore likely
679722 FederationStateV1Servlet,
680723 FederationStateIdsServlet,
681724 FederationBackfillServlet,
725 FederationTimestampLookupServlet,
682726 FederationQueryServlet,
683727 FederationMakeJoinServlet,
684728 FederationMakeLeaveServlet,
700744 RoomComplexityServlet,
701745 FederationSpaceSummaryServlet,
702746 FederationRoomHierarchyServlet,
747 FederationRoomHierarchyUnstableServlet,
703748 FederationV1SendKnockServlet,
704749 FederationMakeKnockServlet,
705750 )
1717 import unicodedata
1818 import urllib.parse
1919 from binascii import crc32
20 from http import HTTPStatus
2021 from typing import (
2122 TYPE_CHECKING,
2223 Any,
3738 import bcrypt
3839 import pymacaroons
3940 import unpaddedbase64
41 from pymacaroons.exceptions import MacaroonVerificationFailedException
4042
4143 from twisted.web.server import Request
4244
180182
181183 user_id = attr.ib(type=str)
182184
183 # the SSO Identity Provider that the user authenticated with, to get this token
184185 auth_provider_id = attr.ib(type=str)
186 """The SSO Identity Provider that the user authenticated with, to get this token."""
187
188 auth_provider_session_id = attr.ib(type=Optional[str])
189 """The session ID advertised by the SSO Identity Provider."""
185190
186191
187192 class AuthHandler:
755760 async def refresh_token(
756761 self,
757762 refresh_token: str,
758 valid_until_ms: Optional[int],
759 ) -> Tuple[str, str]:
763 access_token_valid_until_ms: Optional[int],
764 refresh_token_valid_until_ms: Optional[int],
765 ) -> Tuple[str, str, Optional[int]]:
760766 """
761767 Consumes a refresh token and generate both a new access token and a new refresh token from it.
762768
763769 The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
764770
771 The lifetime of both the access token and refresh token will be capped so that they
772 do not exceed the session's ultimate expiry time, if applicable.
773
765774 Args:
766775 refresh_token: The token to consume.
767 valid_until_ms: The expiration timestamp of the new access token.
768
776 access_token_valid_until_ms: The expiration timestamp of the new access token.
777 None if the access token does not expire.
778 refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
779 None if the refresh token does not expire.
769780 Returns:
770 A tuple containing the new access token and refresh token
781 A tuple containing:
782 - the new access token
783 - the new refresh token
784 - the actual expiry time of the access token, which may be earlier than
785 `access_token_valid_until_ms`.
771786 """
772787
773788 # Verify the token signature first before looking up the token
774789 if not self._verify_refresh_token(refresh_token):
775 raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
790 raise SynapseError(
791 HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
792 )
776793
777794 existing_token = await self.store.lookup_refresh_token(refresh_token)
778795 if existing_token is None:
779 raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
796 raise SynapseError(
797 HTTPStatus.UNAUTHORIZED,
798 "refresh token does not exist",
799 Codes.UNKNOWN_TOKEN,
800 )
780801
781802 if (
782803 existing_token.has_next_access_token_been_used
783804 or existing_token.has_next_refresh_token_been_refreshed
784805 ):
785806 raise SynapseError(
786 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
787 )
807 HTTPStatus.FORBIDDEN,
808 "refresh token isn't valid anymore",
809 Codes.FORBIDDEN,
810 )
811
812 now_ms = self._clock.time_msec()
813
814 if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
815
816 raise SynapseError(
817 HTTPStatus.FORBIDDEN,
818 "The supplied refresh token has expired",
819 Codes.FORBIDDEN,
820 )
821
822 if existing_token.ultimate_session_expiry_ts is not None:
823 # This session has a bounded lifetime, even across refreshes.
824
825 if access_token_valid_until_ms is not None:
826 access_token_valid_until_ms = min(
827 access_token_valid_until_ms,
828 existing_token.ultimate_session_expiry_ts,
829 )
830 else:
831 access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
832
833 if refresh_token_valid_until_ms is not None:
834 refresh_token_valid_until_ms = min(
835 refresh_token_valid_until_ms,
836 existing_token.ultimate_session_expiry_ts,
837 )
838 else:
839 refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
840 if existing_token.ultimate_session_expiry_ts < now_ms:
841 raise SynapseError(
842 HTTPStatus.FORBIDDEN,
843 "The session has expired and can no longer be refreshed",
844 Codes.FORBIDDEN,
845 )
788846
789847 (
790848 new_refresh_token,
791849 new_refresh_token_id,
792850 ) = await self.create_refresh_token_for_user_id(
793 user_id=existing_token.user_id, device_id=existing_token.device_id
851 user_id=existing_token.user_id,
852 device_id=existing_token.device_id,
853 expiry_ts=refresh_token_valid_until_ms,
854 ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
794855 )
795856 access_token = await self.create_access_token_for_user_id(
796857 user_id=existing_token.user_id,
797858 device_id=existing_token.device_id,
798 valid_until_ms=valid_until_ms,
859 valid_until_ms=access_token_valid_until_ms,
799860 refresh_token_id=new_refresh_token_id,
800861 )
801862 await self.store.replace_refresh_token(
802863 existing_token.token_id, new_refresh_token_id
803864 )
804 return access_token, new_refresh_token
865 return access_token, new_refresh_token, access_token_valid_until_ms
805866
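The new refresh logic caps both token lifetimes at the session's ultimate expiry when one is set, and treats a previously unbounded token as expiring with the session. That capping rule reduces to a small helper, sketched here with a hypothetical `cap_expiry` function:

```python
from typing import Optional

def cap_expiry(
    valid_until_ms: Optional[int],
    ultimate_session_expiry_ts: Optional[int],
) -> Optional[int]:
    """Cap a token expiry at the session's ultimate expiry, if any."""
    if ultimate_session_expiry_ts is None:
        # Session lifetime is unbounded; keep the token's own expiry.
        return valid_until_ms
    if valid_until_ms is None:
        # Token was unbounded, but must not outlive the session.
        return ultimate_session_expiry_ts
    return min(valid_until_ms, ultimate_session_expiry_ts)
```

The diff applies this twice, once for `access_token_valid_until_ms` and once for `refresh_token_valid_until_ms`.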
806867 def _verify_refresh_token(self, token: str) -> bool:
807868 """
835896 self,
836897 user_id: str,
837898 device_id: str,
899 expiry_ts: Optional[int],
900 ultimate_session_expiry_ts: Optional[int],
838901 ) -> Tuple[str, int]:
839902 """
840903 Creates a new refresh token for the user with the given user ID.
842905 Args:
843906 user_id: canonical user ID
844907 device_id: the device ID to associate with the token.
908 expiry_ts (milliseconds since the epoch): Time after which the
909 refresh token cannot be used.
 910 If None, the refresh token never expires (it stays valid until used).
911 ultimate_session_expiry_ts (milliseconds since the epoch):
 912 Time at which the session will end and cannot be extended any
 913 further.
914 If None, the session can be refreshed indefinitely.
845915
846916 Returns:
847917 The newly created refresh token and its ID in the database
851921 user_id=user_id,
852922 token=refresh_token,
853923 device_id=device_id,
924 expiry_ts=expiry_ts,
925 ultimate_session_expiry_ts=ultimate_session_expiry_ts,
854926 )
855927 return refresh_token, refresh_token_id
856928
15811653 client_redirect_url: str,
15821654 extra_attributes: Optional[JsonDict] = None,
15831655 new_user: bool = False,
1656 auth_provider_session_id: Optional[str] = None,
15841657 ) -> None:
15851658 """Having figured out a mxid for this user, complete the HTTP request
15861659
15961669 during successful login. Must be JSON serializable.
15971670 new_user: True if we should use wording appropriate to a user who has just
15981671 registered.
1672 auth_provider_session_id: The session ID from the SSO IdP received during login.
15991673 """
16001674 # If the account has been deactivated, do not proceed with the login
16011675 # flow.
16161690 extra_attributes,
16171691 new_user=new_user,
16181692 user_profile_data=profile,
1693 auth_provider_session_id=auth_provider_session_id,
16191694 )
16201695
16211696 def _complete_sso_login(
16271702 extra_attributes: Optional[JsonDict] = None,
16281703 new_user: bool = False,
16291704 user_profile_data: Optional[ProfileInfo] = None,
1705 auth_provider_session_id: Optional[str] = None,
16301706 ) -> None:
16311707 """
16321708 The synchronous portion of complete_sso_login.
16481724
16491725 # Create a login token
16501726 login_token = self.macaroon_gen.generate_short_term_login_token(
1651 registered_user_id, auth_provider_id=auth_provider_id
1727 registered_user_id,
1728 auth_provider_id=auth_provider_id,
1729 auth_provider_session_id=auth_provider_session_id,
16521730 )
16531731
16541732 # Append the login token to the original redirect URL (i.e. with its query
17531831 self,
17541832 user_id: str,
17551833 auth_provider_id: str,
1834 auth_provider_session_id: Optional[str] = None,
17561835 duration_in_ms: int = (2 * 60 * 1000),
17571836 ) -> str:
17581837 macaroon = self._generate_base_macaroon(user_id)
17611840 expiry = now + duration_in_ms
17621841 macaroon.add_first_party_caveat("time < %d" % (expiry,))
17631842 macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
1843 if auth_provider_session_id is not None:
1844 macaroon.add_first_party_caveat(
1845 "auth_provider_session_id = %s" % (auth_provider_session_id,)
1846 )
17641847 return macaroon.serialize()
17651848
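The short-term login token is a macaroon whose first-party caveats now optionally include the IdP session ID. A sketch of the caveat strings as plain values, via a hypothetical `login_token_caveats` helper (the `gen`/`type`/`user_id` caveats are assumed to come from `_generate_base_macaroon`):

```python
from typing import List, Optional

def login_token_caveats(
    user_id: str,
    auth_provider_id: str,
    now_ms: int,
    duration_in_ms: int = 2 * 60 * 1000,
    auth_provider_session_id: Optional[str] = None,
) -> List[str]:
    caveats = [
        "gen = 1",
        "type = login",
        "user_id = %s" % (user_id,),
        "time < %d" % (now_ms + duration_in_ms,),
        "auth_provider_id = %s" % (auth_provider_id,),
    ]
    # Only recorded when the SSO IdP advertised a session ID.
    if auth_provider_session_id is not None:
        caveats.append(
            "auth_provider_session_id = %s" % (auth_provider_session_id,)
        )
    return caveats
```

This mirrors why the verifier gains a matching `satisfy_general` for the `auth_provider_session_id = ` prefix, and why `get_value_from_macaroon` tolerates its absence for tokens minted before this change.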
17661849 def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
17811864 macaroon = pymacaroons.Macaroon.deserialize(token)
17821865 user_id = get_value_from_macaroon(macaroon, "user_id")
17831866 auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
1867
1868 auth_provider_session_id: Optional[str] = None
1869 try:
1870 auth_provider_session_id = get_value_from_macaroon(
1871 macaroon, "auth_provider_session_id"
1872 )
1873 except MacaroonVerificationFailedException:
1874 pass
17841875
17851876 v = pymacaroons.Verifier()
17861877 v.satisfy_exact("gen = 1")
17871878 v.satisfy_exact("type = login")
17881879 v.satisfy_general(lambda c: c.startswith("user_id = "))
17891880 v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
1881 v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
17901882 satisfy_expiry(v, self.hs.get_clock().time_msec)
17911883 v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
17921884
1793 return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
1885 return LoginTokenAttributes(
1886 user_id=user_id,
1887 auth_provider_id=auth_provider_id,
1888 auth_provider_session_id=auth_provider_session_id,
1889 )
17941890
17951891 def generate_delete_pusher_token(self, user_id: str) -> str:
17961892 macaroon = self._generate_base_macaroon(user_id)
300300 user_id: str,
301301 device_id: Optional[str],
302302 initial_device_display_name: Optional[str] = None,
303 auth_provider_id: Optional[str] = None,
304 auth_provider_session_id: Optional[str] = None,
303305 ) -> str:
304306 """
305307 If the given device has not been registered, register it with the
311313 user_id: @user:id
312314 device_id: device id supplied by client
313315 initial_device_display_name: device display name from client
316 auth_provider_id: The SSO IdP the user used, if any.
317 auth_provider_session_id: The session ID (sid) got from the SSO IdP.
314318 Returns:
315319 device id (generated if none was supplied)
316320 """
322326 user_id=user_id,
323327 device_id=device_id,
324328 initial_device_display_name=initial_device_display_name,
329 auth_provider_id=auth_provider_id,
330 auth_provider_session_id=auth_provider_session_id,
325331 )
326332 if new_device:
327333 await self.notify_device_update(user_id, [device_id])
336342 user_id=user_id,
337343 device_id=new_device_id,
338344 initial_device_display_name=initial_device_display_name,
345 auth_provider_id=auth_provider_id,
346 auth_provider_session_id=auth_provider_session_id,
339347 )
340348 if new_device:
341349 await self.notify_device_update(user_id, [new_device_id])
121121 events,
122122 time_now,
123123 as_client_event=as_client_event,
124 # We don't bundle "live" events, as otherwise clients
125 # will end up double counting annotations.
126 bundle_relations=False,
124 # Don't bundle aggregations as this is a deprecated API.
125 bundle_aggregations=False,
127126 )
128127
129128 chunk = {
6767 logger = logging.getLogger(__name__)
6868
6969
70 def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
71 """Get joined domains from state
72
73 Args:
74 state: State map from type/state key to event.
75
76 Returns:
77 Returns a list of servers with the lowest depth of their joins.
78 Sorted by lowest depth first.
79 """
80 joined_users = [
81 (state_key, int(event.depth))
82 for (e_type, state_key), event in state.items()
83 if e_type == EventTypes.Member and event.membership == Membership.JOIN
84 ]
85
86 joined_domains: Dict[str, int] = {}
87 for u, d in joined_users:
88 try:
89 dom = get_domain_from_id(u)
90 old_d = joined_domains.get(dom)
91 if old_d:
92 joined_domains[dom] = min(d, old_d)
93 else:
94 joined_domains[dom] = d
95 except Exception:
96 pass
97
98 return sorted(joined_domains.items(), key=lambda d: d[1])
99
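`get_domains_from_state` was hoisted from a nested function to module level; it keeps, per domain, the lowest join depth seen. A standalone sketch using plain strings and a stand-in event class instead of Synapse's `StateMap`/`EventBase` types:

```python
class MemberEvent:
    """Stand-in for EventBase: just the fields the function reads."""

    def __init__(self, depth: int, membership: str):
        self.depth = depth
        self.membership = membership

def get_domains_from_state(state):
    """Return (domain, lowest join depth) pairs, lowest depth first."""
    # state maps (event_type, state_key) -> event
    joined_users = [
        (state_key, int(event.depth))
        for (e_type, state_key), event in state.items()
        if e_type == "m.room.member" and event.membership == "join"
    ]
    joined_domains = {}
    for user_id, depth in joined_users:
        domain = user_id.split(":", 1)[1]  # '@alice:example.org' -> 'example.org'
        joined_domains[domain] = min(depth, joined_domains.get(domain, depth))
    return sorted(joined_domains.items(), key=lambda d: d[1])
```

The real function additionally swallows exceptions from malformed user IDs rather than failing the whole computation.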
100
70101 class FederationHandler:
71102 """Handles general incoming federation requests
72103
266297 # TODO: HEURISTIC ALERT.
267298
268299 curr_state = await self.state_handler.get_current_state(room_id)
269
270 def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
271 """Get joined domains from state
272
273 Args:
274 state: State map from type/state key to event.
275
276 Returns:
277 Returns a list of servers with the lowest depth of their joins.
278 Sorted by lowest depth first.
279 """
280 joined_users = [
281 (state_key, int(event.depth))
282 for (e_type, state_key), event in state.items()
283 if e_type == EventTypes.Member and event.membership == Membership.JOIN
284 ]
285
286 joined_domains: Dict[str, int] = {}
287 for u, d in joined_users:
288 try:
289 dom = get_domain_from_id(u)
290 old_d = joined_domains.get(dom)
291 if old_d:
292 joined_domains[dom] = min(d, old_d)
293 else:
294 joined_domains[dom] = d
295 except Exception:
296 pass
297
298 return sorted(joined_domains.items(), key=lambda d: d[1])
299300
300301 curr_domains = get_domains_from_state(curr_state)
301302
164164
165165 invite_event = await self.store.get_event(event.event_id)
166166 d["invite"] = await self._event_serializer.serialize_event(
167 invite_event, time_now, as_client_event
167 invite_event,
168 time_now,
169 # Don't bundle aggregations as this is a deprecated API.
170 bundle_aggregations=False,
171 as_client_event=as_client_event,
168172 )
169173
170174 rooms_ret.append(d)
215219 d["messages"] = {
216220 "chunk": (
217221 await self._event_serializer.serialize_events(
218 messages, time_now=time_now, as_client_event=as_client_event
222 messages,
223 time_now=time_now,
224 # Don't bundle aggregations as this is a deprecated API.
225 bundle_aggregations=False,
226 as_client_event=as_client_event,
219227 )
220228 ),
221229 "start": await start_token.to_string(self.store),
225233 d["state"] = await self._event_serializer.serialize_events(
226234 current_state.values(),
227235 time_now=time_now,
236 # Don't bundle aggregations as this is a deprecated API.
237 bundle_aggregations=False,
228238 as_client_event=as_client_event,
229239 )
230240
365375 "room_id": room_id,
366376 "messages": {
367377 "chunk": (
368 await self._event_serializer.serialize_events(messages, time_now)
378 # Don't bundle aggregations as this is a deprecated API.
379 await self._event_serializer.serialize_events(
380 messages, time_now, bundle_aggregations=False
381 )
369382 ),
370383 "start": await start_token.to_string(self.store),
371384 "end": await end_token.to_string(self.store),
372385 },
373386 "state": (
387 # Don't bundle aggregations as this is a deprecated API.
374388 await self._event_serializer.serialize_events(
375 room_state.values(), time_now
389 room_state.values(), time_now, bundle_aggregations=False
376390 )
377391 ),
378392 "presence": [],
391405
392406 # TODO: These concurrently
393407 time_now = self.clock.time_msec()
408 # Don't bundle aggregations as this is a deprecated API.
394409 state = await self._event_serializer.serialize_events(
395 current_state.values(), time_now
410 current_state.values(), time_now, bundle_aggregations=False
396411 )
397412
398413 now_token = self.hs.get_event_sources().get_current_token()
466481 "room_id": room_id,
467482 "messages": {
468483 "chunk": (
469 await self._event_serializer.serialize_events(messages, time_now)
484 # Don't bundle aggregations as this is a deprecated API.
485 await self._event_serializer.serialize_events(
486 messages, time_now, bundle_aggregations=False
487 )
470488 ),
471489 "start": await start_token.to_string(self.store),
472490 "end": await end_token.to_string(self.store),
246246 room_state = room_state_events[membership_event_id]
247247
248248 now = self.clock.time_msec()
249 events = await self._event_serializer.serialize_events(
250 room_state.values(),
251 now,
252 # We don't bother bundling aggregations in when asked for state
253 # events, as clients won't use them.
254 bundle_relations=False,
255 )
249 events = await self._event_serializer.serialize_events(room_state.values(), now)
256250 return events
257251
258252 async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
2222 from authlib.jose import JsonWebToken, jwt
2323 from authlib.oauth2.auth import ClientAuth
2424 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
25 from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
25 from authlib.oidc.core import CodeIDToken, UserInfo
2626 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
2727 from jinja2 import Environment, Template
2828 from pymacaroons.exceptions import (
116116 for idp_id, p in self._providers.items():
117117 try:
118118 await p.load_metadata()
119 await p.load_jwks()
119 if not p._uses_userinfo:
120 await p.load_jwks()
120121 except Exception as e:
121122 raise Exception(
122123 "Error while initialising OIDC provider %r" % (idp_id,)
497498 return await self._jwks.get()
498499
499500 async def _load_jwks(self) -> JWKS:
500 if self._uses_userinfo:
501 # We're not using jwt signing, return an empty jwk set
502 return {"keys": []}
503
504501 metadata = await self.load_metadata()
505502
506503 # Load the JWKS using the `jwks_uri` metadata.
662659
663660 return UserInfo(resp)
664661
665 async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
662 async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
666663 """Return an instance of UserInfo from token's ``id_token``.
667664
668665 Args:
672669 request. This value should match the one inside the token.
673670
674671 Returns:
675 An object representing the user.
672 The decoded claims in the ID token.
676673 """
677674 metadata = await self.load_metadata()
678675 claims_params = {
683680 # If we got an `access_token`, there should be an `at_hash` claim
684681 # in the `id_token` that we can check against.
685682 claims_params["access_token"] = token["access_token"]
686 claims_cls = CodeIDToken
687 else:
688 claims_cls = ImplicitIDToken
689683
690684 alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
691685 jwt = JsonWebToken(alg_values)
702696 claims = jwt.decode(
703697 id_token,
704698 key=jwk_set,
705 claims_cls=claims_cls,
699 claims_cls=CodeIDToken,
706700 claims_options=claim_options,
707701 claims_params=claims_params,
708702 )
712706 claims = jwt.decode(
713707 id_token,
714708 key=jwk_set,
715 claims_cls=claims_cls,
709 claims_cls=CodeIDToken,
716710 claims_options=claim_options,
717711 claims_params=claims_params,
718712 )
720714 logger.debug("Decoded id_token JWT %r; validating", claims)
721715
722716 claims.validate(leeway=120) # allows 2 min of clock skew
723 return UserInfo(claims)
717
718 return claims
724719
725720 async def handle_redirect_request(
726721 self,
836831
837832 logger.debug("Successfully obtained OAuth2 token data: %r", token)
838833
839 # Now that we have a token, get the userinfo, either by decoding the
840 # `id_token` or by fetching the `userinfo_endpoint`.
834 # If there is an id_token, it should be validated regardless of
835 # whether the userinfo endpoint is used or not.
836 if token.get("id_token") is not None:
837 try:
838 id_token = await self._parse_id_token(token, nonce=session_data.nonce)
839 sid = id_token.get("sid")
840 except Exception as e:
841 logger.exception("Invalid id_token")
842 self._sso_handler.render_error(request, "invalid_token", str(e))
843 return
844 else:
845 id_token = None
846 sid = None
847
848 # Now that we have a token, get the userinfo either from the `id_token`
849 # claims or by fetching the `userinfo_endpoint`.
841850 if self._uses_userinfo:
842851 try:
843852 userinfo = await self._fetch_userinfo(token)
845854 logger.exception("Could not fetch userinfo")
846855 self._sso_handler.render_error(request, "fetch_error", str(e))
847856 return
857 elif id_token is not None:
858 userinfo = UserInfo(id_token)
848859 else:
849 try:
850 userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
851 except Exception as e:
852 logger.exception("Invalid id_token")
853 self._sso_handler.render_error(request, "invalid_token", str(e))
854 return
860 logger.error("Missing id_token in token response")
861 self._sso_handler.render_error(
862 request, "invalid_token", "Missing id_token in token response"
863 )
864 return
855865
856866 # first check if we're doing a UIA
857867 if session_data.ui_auth_session_id:
883893 # Call the mapper to register/login the user
884894 try:
885895 await self._complete_oidc_login(
886 userinfo, token, request, session_data.client_redirect_url
896 userinfo, token, request, session_data.client_redirect_url, sid
887897 )
888898 except MappingException as e:
889899 logger.exception("Could not map user")
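The reworked login flow above can be summarised as a pure function (names hypothetical; `parse_id_token` and `fetch_userinfo` stand in for the provider's async helpers): the `id_token`, when present, is always validated and yields the IdP session id `sid`, and the userinfo comes either from the userinfo endpoint or from the `id_token` claims.

```python
# Simplified, synchronous sketch of the new OIDC callback control flow.
def resolve_userinfo(token, uses_userinfo, parse_id_token, fetch_userinfo):
    id_token = None
    sid = None
    if token.get("id_token") is not None:
        # Always validate the id_token when present, even if the userinfo
        # endpoint will be used; this is where `sid` comes from.
        id_token = parse_id_token(token)  # raises on validation failure
        sid = id_token.get("sid")
    if uses_userinfo:
        return fetch_userinfo(token), sid
    if id_token is not None:
        return dict(id_token), sid
    raise ValueError("Missing id_token in token response")
```

The `sid` is then threaded through to `_complete_oidc_login` as `auth_provider_session_id`.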
895905 token: Token,
896906 request: SynapseRequest,
897907 client_redirect_url: str,
908 sid: Optional[str],
898909 ) -> None:
899910 """Given a UserInfo response, complete the login flow
900911
10071018 oidc_response_to_user_attributes,
10081019 grandfather_existing_users,
10091020 extra_attributes,
1021 auth_provider_session_id=sid,
10101022 )
10111023
10121024 def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
405405 force: set true to skip checking for joined users.
406406 """
407407 with await self.pagination_lock.write(room_id):
408 # check we know about the room
409 await self.store.get_room_version_id(room_id)
410
411408 # first check that we have no users in this room
412409 if not force:
413410 joined = await self.store.is_host_joined(room_id, self._server_name)
420420 self._on_shutdown,
421421 )
422422
423 def _on_shutdown(self) -> None:
423 async def _on_shutdown(self) -> None:
424424 if self._presence_enabled:
425425 self.hs.get_tcp_replication().send_command(
426426 ClearUserSyncsCommand(self.instance_id)
00 # Copyright 2014 - 2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
115116 self.pusher_pool = hs.get_pusherpool()
116117
117118 self.session_lifetime = hs.config.registration.session_lifetime
119 self.nonrefreshable_access_token_lifetime = (
120 hs.config.registration.nonrefreshable_access_token_lifetime
121 )
118122 self.refreshable_access_token_lifetime = (
119123 hs.config.registration.refreshable_access_token_lifetime
120124 )
125 self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
121126
122127 init_counters_for_auth_provider("")
123128
740745 is_appservice_ghost: bool = False,
741746 auth_provider_id: Optional[str] = None,
742747 should_issue_refresh_token: bool = False,
748 auth_provider_session_id: Optional[str] = None,
743749 ) -> Tuple[str, str, Optional[int], Optional[str]]:
744750 """Register a device for a user and generate an access token.
745751
750756 device_id: The device ID to check, or None to generate a new one.
751757 initial_display_name: An optional display name for the device.
752758 is_guest: Whether this is a guest account
753 auth_provider_id: The SSO IdP the user used, if any (just used for the
754 prometheus metrics).
759 auth_provider_id: The SSO IdP the user used, if any.
755760 should_issue_refresh_token: Whether it should also issue a refresh token
761 auth_provider_session_id: The session ID received during login from the SSO IdP.
756762 Returns:
757763 Tuple of device ID, access token, access token expiration time and refresh token
758764 """
763769 is_guest=is_guest,
764770 is_appservice_ghost=is_appservice_ghost,
765771 should_issue_refresh_token=should_issue_refresh_token,
772 auth_provider_id=auth_provider_id,
773 auth_provider_session_id=auth_provider_session_id,
766774 )
767775
768776 login_counter.labels(
785793 is_guest: bool = False,
786794 is_appservice_ghost: bool = False,
787795 should_issue_refresh_token: bool = False,
796 auth_provider_id: Optional[str] = None,
797 auth_provider_session_id: Optional[str] = None,
788798 ) -> LoginDict:
789799 """Helper for register_device
790800
792802 class and RegisterDeviceReplicationServlet.
793803 """
794804 assert not self.hs.config.worker.worker_app
795 valid_until_ms = None
805 now_ms = self.clock.time_msec()
806 access_token_expiry = None
796807 if self.session_lifetime is not None:
797808 if is_guest:
798809 raise Exception(
799810 "session_lifetime is not currently implemented for guest access"
800811 )
801 valid_until_ms = self.clock.time_msec() + self.session_lifetime
812 access_token_expiry = now_ms + self.session_lifetime
813
814 if self.nonrefreshable_access_token_lifetime is not None:
815 if access_token_expiry is not None:
816 # Don't allow the non-refreshable access token to outlive the
817 # session.
818 access_token_expiry = min(
819 now_ms + self.nonrefreshable_access_token_lifetime,
820 access_token_expiry,
821 )
822 else:
823 access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime
802824
803825 refresh_token = None
804826 refresh_token_id = None
805827
806828 registered_device_id = await self.device_handler.check_device_registered(
807 user_id, device_id, initial_display_name
829 user_id,
830 device_id,
831 initial_display_name,
832 auth_provider_id=auth_provider_id,
833 auth_provider_session_id=auth_provider_session_id,
808834 )
809835 if is_guest:
810 assert valid_until_ms is None
836 assert access_token_expiry is None
811837 access_token = self.macaroon_gen.generate_guest_access_token(user_id)
812838 else:
813839 if should_issue_refresh_token:
840 # A refreshable access token lifetime must be configured
841 # since we're told to issue a refresh token (the caller checks
842 # that this value is set before setting this flag).
843 assert self.refreshable_access_token_lifetime is not None
844
845 # Set the expiry time of the refreshable access token
846 access_token_expiry = now_ms + self.refreshable_access_token_lifetime
847
848 # Set the refresh token expiry time (if configured)
849 refresh_token_expiry = None
850 if self.refresh_token_lifetime is not None:
851 refresh_token_expiry = now_ms + self.refresh_token_lifetime
852
853 # Set an ultimate session expiry time (if configured)
854 ultimate_session_expiry_ts = None
855 if self.session_lifetime is not None:
856 ultimate_session_expiry_ts = now_ms + self.session_lifetime
857
858 # Also ensure that the issued tokens don't outlive the
859 # session.
860 # (It would be weird to configure a homeserver with a shorter
861 # session lifetime than token lifetime, but may as well handle
862 # it.)
863 access_token_expiry = min(
864 access_token_expiry, ultimate_session_expiry_ts
865 )
866 if refresh_token_expiry is not None:
867 refresh_token_expiry = min(
868 refresh_token_expiry, ultimate_session_expiry_ts
869 )
870
814871 (
815872 refresh_token,
816873 refresh_token_id,
817874 ) = await self._auth_handler.create_refresh_token_for_user_id(
818875 user_id,
819876 device_id=registered_device_id,
820 )
821 valid_until_ms = (
822 self.clock.time_msec() + self.refreshable_access_token_lifetime
877 expiry_ts=refresh_token_expiry,
878 ultimate_session_expiry_ts=ultimate_session_expiry_ts,
823879 )
824880
825881 access_token = await self._auth_handler.create_access_token_for_user_id(
826882 user_id,
827883 device_id=registered_device_id,
828 valid_until_ms=valid_until_ms,
884 valid_until_ms=access_token_expiry,
829885 is_appservice_ghost=is_appservice_ghost,
830886 refresh_token_id=refresh_token_id,
831887 )
833889 return {
834890 "device_id": registered_device_id,
835891 "access_token": access_token,
836 "valid_until_ms": valid_until_ms,
892 "valid_until_ms": access_token_expiry,
837893 "refresh_token": refresh_token,
838894 }
839895
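The expiry bookkeeping introduced above can be condensed into a pure function (hypothetical name, times in milliseconds): each configured lifetime contributes a candidate expiry, and `min()` ensures no token outlives the overall session.

```python
# Sketch of the token-expiry arithmetic; parameter names are illustrative.
def compute_expiries(now_ms, session_lifetime=None,
                     nonrefreshable_lifetime=None,
                     refreshable_lifetime=None,
                     refresh_token_lifetime=None,
                     issue_refresh_token=False):
    access_expiry = None
    if session_lifetime is not None:
        access_expiry = now_ms + session_lifetime
    if nonrefreshable_lifetime is not None:
        # Don't let the non-refreshable access token outlive the session.
        candidate = now_ms + nonrefreshable_lifetime
        access_expiry = candidate if access_expiry is None else min(access_expiry, candidate)

    refresh_expiry = None
    if issue_refresh_token:
        # The caller guarantees this is configured before setting the flag.
        assert refreshable_lifetime is not None
        access_expiry = now_ms + refreshable_lifetime
        if refresh_token_lifetime is not None:
            refresh_expiry = now_ms + refresh_token_lifetime
        if session_lifetime is not None:
            # Cap both tokens at the ultimate session expiry.
            session_end = now_ms + session_lifetime
            access_expiry = min(access_expiry, session_end)
            if refresh_expiry is not None:
                refresh_expiry = min(refresh_expiry, session_end)
    return access_expiry, refresh_expiry
```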
4545 from synapse.api.errors import (
4646 AuthError,
4747 Codes,
48 HttpResponseException,
4849 LimitExceededError,
4950 NotFoundError,
5051 StoreError,
5556 from synapse.event_auth import validate_event_for_room_version
5657 from synapse.events import EventBase
5758 from synapse.events.utils import copy_power_levels_contents
59 from synapse.federation.federation_client import InvalidResponseError
60 from synapse.handlers.federation import get_domains_from_state
5861 from synapse.rest.admin._base import assert_user_is_admin
5962 from synapse.storage.state import StateFilter
6063 from synapse.streams import EventSource
12171220 ).to_string(self.store)
12181221
12191222 return results
1223
1224
1225 class TimestampLookupHandler:
1226 def __init__(self, hs: "HomeServer"):
1227 self.server_name = hs.hostname
1228 self.store = hs.get_datastore()
1229 self.state_handler = hs.get_state_handler()
1230 self.federation_client = hs.get_federation_client()
1231
1232 async def get_event_for_timestamp(
1233 self,
1234 requester: Requester,
1235 room_id: str,
1236 timestamp: int,
1237 direction: str,
1238 ) -> Tuple[str, int]:
1239 """Find the closest event to the given timestamp in the given direction.
1240 If we can't find an event locally or the event we have locally is next to a gap,
1241 it will ask other federated homeservers for an event.
1242
1243 Args:
1244 requester: The user making the request according to the access token
1245 room_id: Room to fetch the event from
1246 timestamp: The point in time (inclusive) we should navigate from in
1247 the given direction to find the closest event.
1248 direction: ["f"|"b"] to indicate whether we should navigate forward
1249 or backward from the given timestamp to find the closest event.
1250
1251 Returns:
1252 A tuple containing the `event_id` closest to the given timestamp in
1253 the given direction and the `origin_server_ts`.
1254
1255 Raises:
1256 SynapseError if unable to find any event locally in the given direction
1257 """
1258
1259 local_event_id = await self.store.get_event_id_for_timestamp(
1260 room_id, timestamp, direction
1261 )
1262 logger.debug(
1263 "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
1264 local_event_id,
1265 timestamp,
1266 )
1267
1268 # Check for gaps in the history where events could be hiding in between
1269 # the timestamp given and the event we were able to find locally
1270 is_event_next_to_backward_gap = False
1271 is_event_next_to_forward_gap = False
1272 if local_event_id:
1273 local_event = await self.store.get_event(
1274 local_event_id, allow_none=False, allow_rejected=False
1275 )
1276
1277 if direction == "f":
1278 # We only need to check for a backward gap if we're looking forwards
1279 # to ensure there is nothing in between.
1280 is_event_next_to_backward_gap = (
1281 await self.store.is_event_next_to_backward_gap(local_event)
1282 )
1283 elif direction == "b":
1284 # We only need to check for a forward gap if we're looking backwards
1285 # to ensure there is nothing in between
1286 is_event_next_to_forward_gap = (
1287 await self.store.is_event_next_to_forward_gap(local_event)
1288 )
1289
1290 # If we found a gap, we should probably ask another homeserver first
1291 # about more history in between
1292 if (
1293 not local_event_id
1294 or is_event_next_to_backward_gap
1295 or is_event_next_to_forward_gap
1296 ):
1297 logger.debug(
1298 "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
1299 local_event_id,
1300 timestamp,
1301 )
1302
1303 # Find other homeservers from the given state in the room
1304 curr_state = await self.state_handler.get_current_state(room_id)
1305 curr_domains = get_domains_from_state(curr_state)
1306 likely_domains = [
1307 domain for domain, depth in curr_domains if domain != self.server_name
1308 ]
1309
1310 # Loop through each homeserver candidate until we get a successful response
1311 for domain in likely_domains:
1312 try:
1313 remote_response = await self.federation_client.timestamp_to_event(
1314 domain, room_id, timestamp, direction
1315 )
1316 logger.debug(
1317 "get_event_for_timestamp: response from domain(%s)=%s",
1318 domain,
1319 remote_response,
1320 )
1321
1322 # TODO: Do we want to persist this as an extremity?
1323 # TODO: I think ideally, we would try to backfill from
1324 # this event and run this whole
1325 # `get_event_for_timestamp` function again to make sure
1326 # they didn't give us an event from their gappy history.
1327 remote_event_id = remote_response.event_id
1328 origin_server_ts = remote_response.origin_server_ts
1329
1330 # Only return the remote event if it's closer than the local event
1331 if not local_event or (
1332 abs(origin_server_ts - timestamp)
1333 < abs(local_event.origin_server_ts - timestamp)
1334 ):
1335 return remote_event_id, origin_server_ts
1336 except (HttpResponseException, InvalidResponseError) as ex:
1337 # Let's not put a high priority on some other homeserver
1338 # failing to respond or giving a random response
1339 logger.debug(
1340 "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
1341 domain,
1342 type(ex).__name__,
1343 ex,
1344 ex.args,
1345 )
1346 except Exception as ex:
1347 # But we do want to see some exceptions in our code
1348 logger.warning(
1349 "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
1350 domain,
1351 type(ex).__name__,
1352 ex,
1353 ex.args,
1354 )
1355
1356 if not local_event_id:
1357 raise SynapseError(
1358 404,
1359 "Unable to find event from %s in direction %s" % (timestamp, direction),
1360 errcode=Codes.NOT_FOUND,
1361 )
1362
1363 return local_event_id, local_event.origin_server_ts
12201364
12211365
12221366 class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
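The selection rule at the heart of `get_event_for_timestamp` — prefer a remote candidate only when it is strictly closer to the requested timestamp than the local one — can be isolated as a small helper. This is a sketch; Synapse compares full event objects rather than `(event_id, origin_server_ts)` tuples.

```python
# Pick whichever candidate event is closer to the requested timestamp.
def pick_closest(timestamp, local, remote):
    """local/remote are (event_id, origin_server_ts) tuples, or None."""
    if local is None:
        return remote
    if remote is None:
        return local
    # Strictly-closer comparison: ties go to the local event.
    if abs(remote[1] - timestamp) < abs(local[1] - timestamp):
        return remote
    return local
```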
13901534 await self.store.block_room(room_id, requester_user_id)
13911535
13921536 if not await self.store.get_room(room_id):
1393 if block:
1394 # We allow you to block an unknown room.
1395 return {
1396 "kicked_users": [],
1397 "failed_to_kick_users": [],
1398 "local_aliases": [],
1399 "new_room_id": None,
1400 }
1401 else:
1402 # But if you don't want to preventatively block another room,
1403 # this function can't do anything useful.
1404 raise NotFoundError(
1405 "Cannot shut down room: unknown room id %s" % (room_id,)
1406 )
1537 # if we don't know about the room, there is nothing left to do.
1538 return {
1539 "kicked_users": [],
1540 "failed_to_kick_users": [],
1541 "local_aliases": [],
1542 "new_room_id": None,
1543 }
14071544
14081545 if new_room_user_id is not None:
14091546 if not self.hs.is_mine_id(new_room_user_id):
3535 SynapseError,
3636 UnsupportedRoomVersionError,
3737 )
38 from synapse.api.ratelimiting import Ratelimiter
3839 from synapse.events import EventBase
39 from synapse.types import JsonDict
40 from synapse.types import JsonDict, Requester
4041 from synapse.util.caches.response_cache import ResponseCache
4142
4243 if TYPE_CHECKING:
9293 self._event_serializer = hs.get_event_client_serializer()
9394 self._server_name = hs.hostname
9495 self._federation_client = hs.get_federation_client()
96 self._ratelimiter = Ratelimiter(
97 store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
98 )
9599
96100 # If a user tries to fetch the same page multiple times in quick succession,
97101 # only process the first attempt and return its result to subsequent requests.
248252
249253 async def get_room_hierarchy(
250254 self,
251 requester: str,
255 requester: Requester,
252256 requested_room_id: str,
253257 suggested_only: bool = False,
254258 max_depth: Optional[int] = None,
275279 Returns:
276280 The JSON hierarchy dictionary.
277281 """
282 await self._ratelimiter.ratelimit(requester)
283
278284 # If a user tries to fetch the same page multiple times in quick succession,
279285 # only process the first attempt and return its result to subsequent requests.
280286 #
282288 # to process multiple requests for the same page will result in errors.
283289 return await self._pagination_response_cache.wrap(
284290 (
285 requester,
291 requester.user.to_string(),
286292 requested_room_id,
287293 suggested_only,
288294 max_depth,
290296 from_token,
291297 ),
292298 self._get_room_hierarchy,
293 requester,
299 requester.user.to_string(),
294300 requested_room_id,
295301 suggested_only,
296302 max_depth,
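The hierarchy endpoint now rate-limits per requester with `rate_hz=5, burst_count=10`. A minimal token-bucket sketch with the same parameters; Synapse's actual `Ratelimiter` is async, store-backed, and supports per-user overrides, none of which is reproduced here.

```python
# Minimal synchronous token-bucket rate limiter (illustrative only).
class SimpleRatelimiter:
    def __init__(self, rate_hz, burst_count):
        self.rate_hz = rate_hz          # tokens refilled per second
        self.burst_count = burst_count  # maximum bucket size
        self.buckets = {}               # key -> (tokens, last update time)

    def can_do_action(self, key, now):
        tokens, last = self.buckets.get(key, (self.burst_count, now))
        # Refill proportionally to elapsed time, capped at the burst size.
        tokens = min(self.burst_count, tokens + (now - last) * self.rate_hz)
        if tokens < 1:
            self.buckets[key] = (tokens, now)
            return False
        self.buckets[key] = (tokens - 1, now)
        return True
```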
364364 sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
365365 grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
366366 extra_login_attributes: Optional[JsonDict] = None,
367 auth_provider_session_id: Optional[str] = None,
367368 ) -> None:
368369 """
369370 Given an SSO ID, retrieve the user ID for it and possibly register the user.
413414
414415 extra_login_attributes: An optional dictionary of extra
415416 attributes to be provided to the client in the login response.
417
418 auth_provider_session_id: An optional session ID from the IdP.
416419
417420 Raises:
418421 MappingException if there was a problem mapping the response to a user.
489492 client_redirect_url,
490493 extra_login_attributes,
491494 new_user=new_user,
495 auth_provider_session_id=auth_provider_session_id,
492496 )
493497
494498 async def _call_attribute_mapper(
333333 full_state: bool,
334334 cache_context: ResponseCacheContext[SyncRequestKey],
335335 ) -> SyncResult:
336 """The start of the machinery that produces a /sync response.
337
338 See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
339
340 This method does high-level bookkeeping:
341 - tracking the kind of sync in the logging context
342 - deleting any to_device messages whose delivery has been acknowledged.
343 - deciding if we should dispatch an instant or delayed response
344 - marking the sync as being lazily loaded, if appropriate
345
346 Computing the body of the response begins in the next method,
347 `current_sync_for_user`.
348 """
336349 if since_token is None:
337350 sync_type = "initial_sync"
338351 elif full_state:
362375 sync_config, since_token, full_state=full_state
363376 )
364377 else:
365
378 # Otherwise, we wait for something to happen and report it to the user.
366379 async def current_sync_callback(
367380 before_token: StreamToken, after_token: StreamToken
368381 ) -> SyncResult:
401414 since_token: Optional[StreamToken] = None,
402415 full_state: bool = False,
403416 ) -> SyncResult:
404 """Get the sync for client needed to match what the server has now."""
417 """Generates the response body of a sync result, represented as a SyncResult.
418
419 This is a wrapper around `generate_sync_result` which starts an open tracing
420 span to track the sync. See `generate_sync_result` for the next part of your
421 indoctrination.
422 """
405423 with start_active_span("current_sync_for_user"):
406424 log_kv({"since_token": since_token})
407425 sync_result = await self.generate_sync_result(
559577 # that have happened since `since_key` up to `end_key`, so we
560578 # can just use `get_room_events_stream_for_room`.
561579 # Otherwise, we want to return the last N events in the room
562 # in toplogical ordering.
580 # in topological ordering.
563581 if since_key:
564582 events, end_key = await self.store.get_room_events_stream_for_room(
565583 room_id,
10411059 since_token: Optional[StreamToken] = None,
10421060 full_state: bool = False,
10431061 ) -> SyncResult:
1044 """Generates a sync result."""
1062 """Generates the response body of a sync result.
1063
1064 This is represented by a `SyncResult` struct, which is built from small pieces
1065 using a `SyncResultBuilder`. See also
1066 https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
1067 The `sync_result_builder` is passed as a mutable ("inout") parameter to various
1068 helper functions. These retrieve and process the data which forms the sync body,
1069 often writing to the `sync_result_builder` to store their output.
1070
1071 At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
1072 instance to signify that the sync calculation is complete.
1073 """
10451074 # NB: The now_token gets changed by some of the generate_sync_* methods,
10461075 # this is due to some of the underlying streams not supporting the ability
10471076 # to query up to a given point.
13431372 async def _generate_sync_entry_for_account_data(
13441373 self, sync_result_builder: "SyncResultBuilder"
13451374 ) -> Dict[str, Dict[str, JsonDict]]:
1346 """Generates the account data portion of the sync response. Populates
1347 `sync_result_builder` with the result.
1375 """Generates the account data portion of the sync response.
1376
1377 Account data (called "Client Config" in the spec) can be set either globally
1378 or for a specific room. Account data consists of a list of events which
1379 accumulate state, much like a room.
1380
1381 This function retrieves global and per-room account data. The former is written
1382 to the given `sync_result_builder`. The latter is returned directly, to be
1383 later written to the `sync_result_builder` on a room-by-room basis.
13481384
13491385 Args:
13501386 sync_result_builder
13511387
13521388 Returns:
1353 A dictionary containing the per room account data.
1389 A dictionary whose keys (room ids) map to the per room account data for that
1390 room.
13541391 """
13551392 sync_config = sync_result_builder.sync_config
13561393 user_id = sync_result_builder.sync_config.user.to_string()
13581395
13591396 if since_token and not sync_result_builder.full_state:
13601397 (
1361 account_data,
1398 global_account_data,
13621399 account_data_by_room,
13631400 ) = await self.store.get_updated_account_data_for_user(
13641401 user_id, since_token.account_data_key
13691406 )
13701407
13711408 if push_rules_changed:
1372 account_data["m.push_rules"] = await self.push_rules_for_user(
1409 global_account_data["m.push_rules"] = await self.push_rules_for_user(
13731410 sync_config.user
13741411 )
13751412 else:
13761413 (
1377 account_data,
1414 global_account_data,
13781415 account_data_by_room,
13791416 ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
13801417
1381 account_data["m.push_rules"] = await self.push_rules_for_user(
1418 global_account_data["m.push_rules"] = await self.push_rules_for_user(
13821419 sync_config.user
13831420 )
13841421
13851422 account_data_for_user = await sync_config.filter_collection.filter_account_data(
13861423 [
13871424 {"type": account_data_type, "content": content}
1388 for account_data_type, content in account_data.items()
1425 for account_data_type, content in global_account_data.items()
13891426 ]
13901427 )
13911428
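The rename above distinguishes global account data from the per-room map. A stripped-down sketch of the shape of that split (illustrative names, not Synapse's storage API): global account data, augmented with the user's push rules, becomes a list of `{type, content}` events, while per-room data is returned keyed by room id for later room-by-room processing.

```python
# Hypothetical helper mirroring the global/per-room account data split.
def split_account_data(global_data, by_room, push_rules):
    global_data = dict(global_data)
    # Push rules are surfaced to clients as the m.push_rules account data type.
    global_data["m.push_rules"] = push_rules
    events = [{"type": t, "content": c} for t, c in global_data.items()]
    return events, by_room
```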
14591496 """Generates the rooms portion of the sync response. Populates the
14601497 `sync_result_builder` with the result.
14611498
1499 In the response that reaches the client, rooms are divided into four categories:
1500 `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
1501 room ids returned by this function.
1502
14621503 Args:
14631504 sync_result_builder
14641505 account_data_by_room: Dictionary of per room account data
14651506
14661507 Returns:
1467 Returns a 4-tuple of
1468 `(newly_joined_rooms, newly_joined_or_invited_users,
1469 newly_left_rooms, newly_left_users)`
1470 """
1508 Returns a 4-tuple describing rooms the user has joined or left, and users who've
1509 joined or left any rooms the user is in. This gets used later in
1510 `_generate_sync_entry_for_device_list`.
1511
1512 Its entries are:
1513 - newly_joined_rooms
1514 - newly_joined_or_invited_or_knocked_users
1515 - newly_left_rooms
1516 - newly_left_users
1517 """
1518 since_token = sync_result_builder.since_token
1519
1520 # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
14711521 user_id = sync_result_builder.sync_config.user.to_string()
14721522 block_all_room_ephemeral = (
1473 sync_result_builder.since_token is None
1523 since_token is None
14741524 and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
14751525 )
14761526
14841534 )
14851535 sync_result_builder.now_token = now_token
14861536
1487 # We check up front if anything has changed, if it hasn't then there is
1537 # 2. We check up front if anything has changed, if it hasn't then there is
14881538 # no point in going further.
1489 since_token = sync_result_builder.since_token
14901539 if not sync_result_builder.full_state:
14911540 if since_token and not ephemeral_by_room and not account_data_by_room:
14921541 have_changed = await self._have_rooms_changed(sync_result_builder)
14991548 logger.debug("no-oping sync")
15001549 return set(), set(), set(), set()
15011550
1502 ignored_account_data = (
1503 await self.store.get_global_account_data_by_type_for_user(
1504 AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
1505 )
1506 )
1507
1508 # If there is ignored users account data and it matches the proper type,
1509 # then use it.
1510 ignored_users: FrozenSet[str] = frozenset()
1511 if ignored_account_data:
1512 ignored_users_data = ignored_account_data.get("ignored_users", {})
1513 if isinstance(ignored_users_data, dict):
1514 ignored_users = frozenset(ignored_users_data.keys())
1515
1551 # 3. Work out which rooms need reporting in the sync response.
1552 ignored_users = await self._get_ignored_users(user_id)
15161553 if since_token:
15171554 room_changes = await self._get_rooms_changed(
15181555 sync_result_builder, ignored_users
15221559 )
15231560 else:
15241561 room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
1525
15261562 tags_by_room = await self.store.get_tags_for_user(user_id)
15271563
15281564 log_kv({"rooms_changed": len(room_changes.room_entries)})
15331569 newly_joined_rooms = room_changes.newly_joined_rooms
15341570 newly_left_rooms = room_changes.newly_left_rooms
15351571
1572 # 4. We need to apply further processing to `room_entries` (rooms considered
1573 # joined or archived).
15361574 async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
15371575 logger.debug("Generating room entry for %s", room_entry.room_id)
15381576 await self._generate_room_entry(
15511589 sync_result_builder.invited.extend(invited)
15521590 sync_result_builder.knocked.extend(knocked)
15531591
1554 # Now we want to get any newly joined, invited or knocking users
1555 newly_joined_or_invited_or_knocked_users = set()
1556 newly_left_users = set()
1557 if since_token:
1558 for joined_sync in sync_result_builder.joined:
1559 it = itertools.chain(
1560 joined_sync.timeline.events, joined_sync.state.values()
1561 )
1562 for event in it:
1563 if event.type == EventTypes.Member:
1564 if (
1565 event.membership == Membership.JOIN
1566 or event.membership == Membership.INVITE
1567 or event.membership == Membership.KNOCK
1568 ):
1569 newly_joined_or_invited_or_knocked_users.add(
1570 event.state_key
1571 )
1572 else:
1573 prev_content = event.unsigned.get("prev_content", {})
1574 prev_membership = prev_content.get("membership", None)
1575 if prev_membership == Membership.JOIN:
1576 newly_left_users.add(event.state_key)
1577
1578 newly_left_users -= newly_joined_or_invited_or_knocked_users
1592 # 5. Work out which users have joined or left rooms we're in. We use this
1593 # to build the device_list part of the sync response in
1594 # `_generate_sync_entry_for_device_list`.
1595 (
1596 newly_joined_or_invited_or_knocked_users,
1597 newly_left_users,
1598 ) = sync_result_builder.calculate_user_changes()
15791599
15801600 return (
15811601 set(newly_joined_rooms),
15841604 newly_left_users,
15851605 )
15861606
1607 async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
1608 """Retrieve the users ignored by the given user from their global account_data.
1609
1610 Returns an empty set if
1611 - there is no global account_data entry for ignored_users
1612 - there is such an entry, but it's not a JSON object.
1613 """
1614 # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
1615 ignored_account_data = (
1616 await self.store.get_global_account_data_by_type_for_user(
1617 AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
1618 )
1619 )
1620
1621 # If there is ignored users account data and it matches the proper type,
1622 # then use it.
1623 ignored_users: FrozenSet[str] = frozenset()
1624 if ignored_account_data:
1625 ignored_users_data = ignored_account_data.get("ignored_users", {})
1626 if isinstance(ignored_users_data, dict):
1627 ignored_users = frozenset(ignored_users_data.keys())
1628 return ignored_users
1629
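As an aside, the defensive parsing in `_get_ignored_users` can be sketched in isolation; `parse_ignored_users` below is a hypothetical standalone helper, not part of Synapse:

```python
from typing import FrozenSet, Optional

def parse_ignored_users(account_data: Optional[dict]) -> FrozenSet[str]:
    # Mirrors _get_ignored_users: only trust the account data if the
    # "ignored_users" field is a JSON object; anything else (missing
    # entry, list, string, ...) yields an empty set.
    if not account_data:
        return frozenset()
    ignored = account_data.get("ignored_users", {})
    if isinstance(ignored, dict):
        return frozenset(ignored.keys())
    return frozenset()
```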
15871630 async def _have_rooms_changed(
15881631 self, sync_result_builder: "SyncResultBuilder"
15891632 ) -> bool:
15901633 """Returns whether there may be any new events that should be sent down
15911634 the sync. Returns True if there are.
1635
1636 Does not modify the `sync_result_builder`.
15921637 """
15931638 user_id = sync_result_builder.sync_config.user.to_string()
15941639 since_token = sync_result_builder.since_token
15961641
15971642 assert since_token
15981643
1599 # Get a list of membership change events that have happened.
1600 rooms_changed = await self.store.get_membership_changes_for_user(
1644 # Get a list of membership change events that have happened to the user
1645 # requesting the sync.
1646 membership_changes = await self.store.get_membership_changes_for_user(
16011647 user_id, since_token.room_key, now_token.room_key
16021648 )
16031649
1604 if rooms_changed:
1650 if membership_changes:
16051651 return True
16061652
16071653 stream_id = since_token.room_key.stream
16131659 async def _get_rooms_changed(
16141660 self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
16151661 ) -> _RoomChanges:
1616 """Gets the the changes that have happened since the last sync."""
1662 """Determine the changes in rooms to report to the user.
1663
1664 Ideally, we want to report all events whose stream ordering `s` lies in the
1665 range `since_token < s <= now_token`, where the two tokens are read from the
1666 sync_result_builder.
1667
1668 If there are too many events in that range to report, things get complicated.
1669 In this situation we return a truncated list of the most recent events, and
1670 indicate in the response that there is a "gap" of omitted events. Additionally:
1671
1672 - we include a "state_delta", to describe the changes in state over the gap,
1673 - we include all membership events applying to the user making the request,
1674 even those in the gap.
1675
1676 See the spec for the rationale:
1677 https://spec.matrix.org/v1.1/client-server-api/#syncing
1678
1679 The sync_result_builder is not modified by this function.
1680 """
16171681 user_id = sync_result_builder.sync_config.user.to_string()
16181682 since_token = sync_result_builder.since_token
16191683 now_token = sync_result_builder.now_token
16211685
16221686 assert since_token
16231687
1624 # Get a list of membership change events that have happened.
1625 rooms_changed = await self.store.get_membership_changes_for_user(
1688 # The spec
1689 # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
1690 # notes that membership events need special consideration:
1691 #
1692 # > When a sync is limited, the server MUST return membership events for events
1693 # > in the gap (between since and the start of the returned timeline), regardless
1694 # > as to whether or not they are redundant.
1695 #
1696 # We fetch such events here, but we only seem to use them for categorising rooms
1697 # as newly joined, newly left, invited or knocked.
1698 # TODO: we've already called this function and run this query in
1699 # _have_rooms_changed. We could keep the results in memory to avoid a
1700 # second query, at the cost of more complicated source code.
1701 membership_change_events = await self.store.get_membership_changes_for_user(
16261702 user_id, since_token.room_key, now_token.room_key
16271703 )
16281704
16291705 mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
1630 for event in rooms_changed:
1706 for event in membership_change_events:
16311707 mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
16321708
1633 newly_joined_rooms = []
1634 newly_left_rooms = []
1635 room_entries = []
1636 invited = []
1637 knocked = []
1709 newly_joined_rooms: List[str] = []
1710 newly_left_rooms: List[str] = []
1711 room_entries: List[RoomSyncResultBuilder] = []
1712 invited: List[InvitedSyncResult] = []
1713 knocked: List[KnockedSyncResult] = []
16381714 for room_id, events in mem_change_events_by_room_id.items():
1715 # The body of this loop will add this room to at least one of the five lists
1716 # above. Things get messy if you've e.g. joined, left, joined then left the
1717 # room all in the same sync period.
16391718 logger.debug(
16401719 "Membership changes in %s: [%s]",
16411720 room_id,
16901769
16911770 if not non_joins:
16921771 continue
1772 last_non_join = non_joins[-1]
16931773
16941774 # Check if we have left the room. This can either be because we were
16951775 # joined before *or* that we since joined and then left.
17111791 newly_left_rooms.append(room_id)
17121792
17131793 # Only bother if we're still currently invited
1714 should_invite = non_joins[-1].membership == Membership.INVITE
1794 should_invite = last_non_join.membership == Membership.INVITE
17151795 if should_invite:
1716 if event.sender not in ignored_users:
1717 invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
1796 if last_non_join.sender not in ignored_users:
1797 invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
17181798 if invite_room_sync:
17191799 invited.append(invite_room_sync)
17201800
17211801 # Only bother if our latest membership in the room is knock (and we haven't
17221802 # been accepted/rejected in the meantime).
1723 should_knock = non_joins[-1].membership == Membership.KNOCK
1803 should_knock = last_non_join.membership == Membership.KNOCK
17241804 if should_knock:
1725 knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
1805 knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join)
17261806 if knock_room_sync:
17271807 knocked.append(knock_room_sync)
17281808
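The invite/knock handling above hinges on the room's *last* non-join membership event in the sync window. A simplified, hypothetical sketch of that decision (names illustrative, not Synapse's API):

```python
INVITE, KNOCK = "invite", "knock"

def categorise(non_join_memberships, ignored_senders=frozenset()):
    # non_join_memberships: list of (membership, sender) tuples, oldest first.
    # Returns "invited", "knocked" or None, mirroring the should_invite /
    # should_knock checks above (an invite from an ignored sender is dropped).
    if not non_join_memberships:
        return None
    last_membership, sender = non_join_memberships[-1]
    if last_membership == INVITE and sender not in ignored_senders:
        return "invited"
    if last_membership == KNOCK:
        return "knocked"
    return None
```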
17801860
17811861 timeline_limit = sync_config.filter_collection.timeline_limit()
17821862
1783 # Get all events for rooms we're currently joined to.
1863 # Get all events since the `from_key` in rooms we're currently joined to.
1864 # If there are too many, we get the most recent events only. This leaves
1865 # a "gap" in the timeline, as described by the spec for /sync.
17841866 room_to_events = await self.store.get_room_events_stream_for_rooms(
17851867 room_ids=sync_result_builder.joined_room_ids,
17861868 from_key=since_token.room_key,
18411923 ) -> _RoomChanges:
18421924 """Returns entries for all rooms for the user.
18431925
1926 Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
1927
1928 This function does not modify the sync_result_builder.
1929
18441930 Args:
18451931 sync_result_builder
18461932 ignored_users: Set of users ignored by user.
18521938 now_token = sync_result_builder.now_token
18531939 sync_config = sync_result_builder.sync_config
18541940
1855 membership_list = (
1856 Membership.INVITE,
1857 Membership.KNOCK,
1858 Membership.JOIN,
1859 Membership.LEAVE,
1860 Membership.BAN,
1861 )
1862
18631941 room_list = await self.store.get_rooms_for_local_user_where_membership_is(
1864 user_id=user_id, membership_list=membership_list
1942 user_id=user_id,
1943 membership_list=Membership.LIST,
18651944 )
18661945
18671946 room_entries = []
22112290 # to only include membership events for the senders in the timeline.
22122291 # In practice, we can do this by removing them from the p_ids list,
22132292 # which is the list of relevant state we know we have already sent to the client.
2214 # see https://github.com/matrix-org/synapse/pull/2970
2215 # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
2293 # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
22162294
22172295 if lazy_load_members:
22182296 p_ids.difference_update(
22612339 groups: Optional[GroupsSyncResult] = None
22622340 to_device: List[JsonDict] = attr.Factory(list)
22632341
2342 def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
2343 """Work out which other users have joined or left rooms we are joined to.
2344
2345 This data is only useful for an incremental sync.
2346
2347 The SyncResultBuilder is not modified by this function.
2348 """
2349 newly_joined_or_invited_or_knocked_users = set()
2350 newly_left_users = set()
2351 if self.since_token:
2352 for joined_sync in self.joined:
2353 it = itertools.chain(
2354 joined_sync.timeline.events, joined_sync.state.values()
2355 )
2356 for event in it:
2357 if event.type == EventTypes.Member:
2358 if (
2359 event.membership == Membership.JOIN
2360 or event.membership == Membership.INVITE
2361 or event.membership == Membership.KNOCK
2362 ):
2363 newly_joined_or_invited_or_knocked_users.add(
2364 event.state_key
2365 )
2366 else:
2367 prev_content = event.unsigned.get("prev_content", {})
2368 prev_membership = prev_content.get("membership", None)
2369 if prev_membership == Membership.JOIN:
2370 newly_left_users.add(event.state_key)
2371
2372 newly_left_users -= newly_joined_or_invited_or_knocked_users
2373 return newly_joined_or_invited_or_knocked_users, newly_left_users
2374
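Stripped of Synapse's event types, the logic of `calculate_user_changes` reduces to a small set computation; the sketch below assumes membership events flattened to `(state_key, membership, prev_membership)` tuples:

```python
def calculate_user_changes(membership_events):
    # Scan membership events from joined rooms and split users into
    # "newly joined/invited/knocked" vs "newly left"; a user appearing in
    # both sets counts as joined, exactly as in the method above.
    newly_joined_or_invited_or_knocked = set()
    newly_left = set()
    for state_key, membership, prev_membership in membership_events:
        if membership in ("join", "invite", "knock"):
            newly_joined_or_invited_or_knocked.add(state_key)
        elif prev_membership == "join":
            newly_left.add(state_key)
    # A user who left and then rejoined is not "newly left".
    newly_left -= newly_joined_or_invited_or_knocked
    return newly_joined_or_invited_or_knocked, newly_left
```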
22642375
22652376 @attr.s(slots=True, auto_attribs=True)
22662377 class RoomSyncResultBuilder:
7676 """
7777 args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
7878 return parse_integer_from_args(args, name, default, required)
79
80
81 @overload
82 def parse_integer_from_args(
83 args: Mapping[bytes, Sequence[bytes]],
84 name: str,
85 default: Optional[int] = None,
86 ) -> Optional[int]:
87 ...
88
89
90 @overload
91 def parse_integer_from_args(
92 args: Mapping[bytes, Sequence[bytes]],
93 name: str,
94 *,
95 required: Literal[True],
96 ) -> int:
97 ...
98
99
100 @overload
101 def parse_integer_from_args(
102 args: Mapping[bytes, Sequence[bytes]],
103 name: str,
104 default: Optional[int] = None,
105 required: bool = False,
106 ) -> Optional[int]:
107 ...
79108
80109
81110 def parse_integer_from_args(
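The `@overload` stubs added above let mypy infer a non-`Optional` return when `required=True` is passed as a literal. The same pattern in a self-contained, hypothetical `parse_int`:

```python
from typing import Literal, Mapping, Optional, overload

@overload
def parse_int(
    args: Mapping[str, str], name: str, default: Optional[int] = None
) -> Optional[int]:
    ...

@overload
def parse_int(args: Mapping[str, str], name: str, *, required: Literal[True]) -> int:
    ...

def parse_int(args, name, default=None, required=False):
    # Runtime implementation shared by both overloads: the stubs only
    # affect type checking.
    value = args.get(name)
    if value is None:
        if required:
            raise KeyError(name)
        return default
    return int(value)
```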
2323 List,
2424 Optional,
2525 Tuple,
26 TypeVar,
2627 Union,
2728 )
2829
8081 )
8182 from synapse.http.servlet import parse_json_object_from_request
8283 from synapse.http.site import SynapseRequest
83 from synapse.logging.context import make_deferred_yieldable, run_in_background
84 from synapse.logging.context import (
85 defer_to_thread,
86 make_deferred_yieldable,
87 run_in_background,
88 )
8489 from synapse.metrics.background_process_metrics import run_as_background_process
8590 from synapse.rest.client.login import LoginResponse
8691 from synapse.storage import DataStore
92 from synapse.storage.background_updates import (
93 DEFAULT_BATCH_SIZE_CALLBACK,
94 MIN_BATCH_SIZE_CALLBACK,
95 ON_UPDATE_CALLBACK,
96 )
8797 from synapse.storage.database import DatabasePool, LoggingTransaction
8898 from synapse.storage.databases.main.roommember import ProfileInfo
8999 from synapse.storage.state import StateFilter
97107 create_requester,
98108 )
99109 from synapse.util import Clock
110 from synapse.util.async_helpers import maybe_awaitable
100111 from synapse.util.caches.descriptors import cached
101112
102113 if TYPE_CHECKING:
103114 from synapse.app.generic_worker import GenericWorkerSlavedStore
104115 from synapse.server import HomeServer
116
117
118 T = TypeVar("T")
105119
106120 """
107121 This package defines the 'stable' API which can be used by extension modules which
306320 auth_checkers=auth_checkers,
307321 )
308322
309 def register_web_resource(self, path: str, resource: Resource):
323 def register_background_update_controller_callbacks(
324 self,
325 on_update: ON_UPDATE_CALLBACK,
326 default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
327 min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
328 ) -> None:
329 """Registers background update controller callbacks.
330
331 Added in Synapse v1.49.0.
332 """
333
334 for db in self._hs.get_datastores().databases:
335 db.updates.register_update_controller_callbacks(
336 on_update=on_update,
337 default_batch_size=default_batch_size,
338 min_batch_size=min_batch_size,
339 )
340
341 def register_web_resource(self, path: str, resource: Resource) -> None:
310342 """Registers a web resource to be served at the given path.
311343
312344 This function should be called during initialisation of the module.
431463 username: provided user id
432464
433465 Returns:
434 str: qualified @user:id
466 qualified @user:id
435467 """
436468 if username.startswith("@"):
437469 return username
467499 """
468500 return await self._store.user_get_threepids(user_id)
469501
470 def check_user_exists(self, user_id: str):
502 def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
471503 """Check if user exists.
472504
473505 Added in Synapse v0.25.0.
476508 user_id: Complete @user:id
477509
478510 Returns:
479 Deferred[str|None]: Canonical (case-corrected) user_id, or None
511 Canonical (case-corrected) user_id, or None
480512 if the user is not registered.
481513 """
482514 return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
483515
484516 @defer.inlineCallbacks
485 def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
517 def register(
518 self,
519 localpart: str,
520 displayname: Optional[str] = None,
521 emails: Optional[List[str]] = None,
522 ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]:
486523 """Registers a new user with given localpart and optional displayname, emails.
487524
488525 Also returns an access token for the new user.
494531 Added in Synapse v0.25.0.
495532
496533 Args:
497 localpart (str): The localpart of the new user.
498 displayname (str|None): The displayname of the new user.
499 emails (List[str]): Emails to bind to the new user.
500
501 Returns:
502 Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
534 localpart: The localpart of the new user.
535 displayname: The displayname of the new user.
536 emails: Emails to bind to the new user.
537
538 Returns:
539 a 2-tuple of (user_id, access_token)
503540 """
504541 logger.warning(
505542 "Using deprecated ModuleApi.register which creates a dummy user device."
509546 return user_id, access_token
510547
511548 def register_user(
512 self, localpart, displayname=None, emails: Optional[List[str]] = None
513 ):
549 self,
550 localpart: str,
551 displayname: Optional[str] = None,
552 emails: Optional[List[str]] = None,
553 ) -> "defer.Deferred[str]":
514554 """Registers a new user with given localpart and optional displayname, emails.
515555
516556 Added in Synapse v1.2.0.
517557
518558 Args:
519 localpart (str): The localpart of the new user.
520 displayname (str|None): The displayname of the new user.
521 emails (List[str]): Emails to bind to the new user.
559 localpart: The localpart of the new user.
560 displayname: The displayname of the new user.
561 emails: Emails to bind to the new user.
522562
523563 Raises:
524564 SynapseError if there is an error performing the registration. Check the
525565 'errcode' property for more information on the reason for failure
526566
527567 Returns:
528 defer.Deferred[str]: user_id
568 user_id
529569 """
530570 return defer.ensureDeferred(
531571 self._hs.get_registration_handler().register_user(
535575 )
536576 )
537577
538 def register_device(self, user_id, device_id=None, initial_display_name=None):
578 def register_device(
579 self,
580 user_id: str,
581 device_id: Optional[str] = None,
582 initial_display_name: Optional[str] = None,
583 ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]":
539584 """Register a device for a user and generate an access token.
540585
541586 Added in Synapse v1.2.0.
542587
543588 Args:
544 user_id (str): full canonical @user:id
545 device_id (str|None): The device ID to check, or None to generate
589 user_id: full canonical @user:id
590 device_id: The device ID to check, or None to generate
546591 a new one.
547 initial_display_name (str|None): An optional display name for the
592 initial_display_name: An optional display name for the
548593 device.
549594
550595 Returns:
551 defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
596 Tuple of device ID, access token, access token expiration time and refresh token
552597 """
553598 return defer.ensureDeferred(
554599 self._hs.get_registration_handler().register_device(
581626 user_id: str,
582627 duration_in_ms: int = (2 * 60 * 1000),
583628 auth_provider_id: str = "",
629 auth_provider_session_id: Optional[str] = None,
584630 ) -> str:
585631 """Generate a login token suitable for m.login.token authentication
586632
598644 return self._hs.get_macaroon_generator().generate_short_term_login_token(
599645 user_id,
600646 auth_provider_id,
647 auth_provider_session_id,
601648 duration_in_ms,
602649 )
603650
604651 @defer.inlineCallbacks
605 def invalidate_access_token(self, access_token):
652 def invalidate_access_token(
653 self, access_token: str
654 ) -> Generator["defer.Deferred[Any]", Any, None]:
606655 """Invalidate an access token for a user
607656
608657 Added in Synapse v0.25.0.
634683 self._auth_handler.delete_access_token(access_token)
635684 )
636685
637 def run_db_interaction(self, desc, func, *args, **kwargs):
686 def run_db_interaction(
687 self,
688 desc: str,
689 func: Callable[..., T],
690 *args: Any,
691 **kwargs: Any,
692 ) -> "defer.Deferred[T]":
638693 """Run a function with a database connection
639694
640695 Added in Synapse v0.25.0.
641696
642697 Args:
643 desc (str): description for the transaction, for metrics etc
644 func (func): function to be run. Passed a database cursor object
698 desc: description for the transaction, for metrics etc
699 func: function to be run. Passed a database cursor object
645700 as well as *args and **kwargs
646701 *args: positional args to be passed to func
647702 **kwargs: named args to be passed to func
655710
656711 def complete_sso_login(
657712 self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
658 ):
713 ) -> None:
659714 """Complete a SSO login by redirecting the user to a page to confirm whether they
660715 want their access token sent to `client_redirect_url`, or redirect them to that
661716 URL with a token directly if the URL matches with one of the whitelisted clients.
685740 client_redirect_url: str,
686741 new_user: bool = False,
687742 auth_provider_id: str = "<unknown>",
688 ):
743 ) -> None:
689744 """Complete a SSO login by redirecting the user to a page to confirm whether they
690745 want their access token sent to `client_redirect_url`, or redirect them to that
691746 URL with a token directly if the URL matches with one of the whitelisted clients.
924979 self,
925980 f: Callable,
926981 msec: float,
927 *args,
982 *args: object,
928983 desc: Optional[str] = None,
929984 run_on_all_instances: bool = False,
930 **kwargs,
931 ):
985 **kwargs: object,
986 ) -> None:
932987 """Wraps a function as a background process and calls it repeatedly.
933988
934989 NOTE: Will only run on the instance that is configured to run
9591014 run_as_background_process,
9601015 msec,
9611016 desc,
962 f,
963 *args,
964 **kwargs,
1017 lambda: maybe_awaitable(f(*args, **kwargs)),
9651018 )
9661019 else:
9671020 logger.warning(
9691022 f,
9701023 )
9711024
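`looping_background_call` now wraps the callable via `maybe_awaitable`, so modules can register either plain or async functions. A hypothetical standalone version of such a helper, using asyncio rather than Twisted:

```python
import asyncio
import inspect

async def maybe_awaitable(value):
    # If the callable returned an awaitable, await it; otherwise pass the
    # plain value straight through. (Simplified stand-in for Synapse's
    # synapse.util.async_helpers.maybe_awaitable.)
    if inspect.isawaitable(value):
        return await value
    return value

def sync_f():
    return 1

async def async_f():
    return 2

async def main():
    # Both call styles work through the same wrapper.
    return await maybe_awaitable(sync_f()), await maybe_awaitable(async_f())
```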
1025 async def sleep(self, seconds: float) -> None:
1026 """Sleeps for the given number of seconds."""
1027
1028 await self._clock.sleep(seconds)
1029
9721030 async def send_mail(
9731031 self,
9741032 recipient: str,
9751033 subject: str,
9761034 html: str,
9771035 text: str,
978 ):
1036 ) -> None:
9791037 """Send an email on behalf of the homeserver.
9801038
9811039 Added in Synapse v1.39.0.
11231181
11241182 return {key: state_events[event_id] for key, event_id in state_ids.items()}
11251183
1184 async def defer_to_thread(
1185 self,
1186 f: Callable[..., T],
1187 *args: Any,
1188 **kwargs: Any,
1189 ) -> T:
1190 """Runs the given function in a separate thread from Synapse's thread pool.
1191
1192 Added in Synapse v1.49.0.
1193
1194 Args:
1195 f: The function to run.
1196 args: The function's arguments.
1197 kwargs: The function's keyword arguments.
1198
1199 Returns:
1200 The return value of the function once ran in a thread.
1201 """
1202 return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
1203
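`ModuleApi.defer_to_thread` hands blocking work to Twisted's reactor thread pool. Outside Twisted, `asyncio.to_thread` (Python 3.9+) offers the analogous pattern; the sketch below is an asyncio illustration, not Synapse code:

```python
import asyncio
import threading

def blocking_work(x: int) -> int:
    # Pretend this does slow, blocking I/O; record which thread ran it so
    # we can see it was moved off the event-loop thread.
    blocking_work.thread = threading.current_thread()
    return x * 2

async def main() -> int:
    # asyncio analogue of ModuleApi.defer_to_thread (which uses Twisted's
    # deferToThread under the hood).
    return await asyncio.to_thread(blocking_work, 21)
```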
11261204
11271205 class PublicRoomListManager:
11281206 """Contains methods for adding to, removing from and querying whether a room
2020 from synapse.metrics.background_process_metrics import run_as_background_process
2121 from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams
2222 from synapse.push.mailer import Mailer
23 from synapse.push.push_types import EmailReason
24 from synapse.storage.databases.main.event_push_actions import EmailPushAction
2325 from synapse.util.threepids import validate_email
2426
2527 if TYPE_CHECKING:
189191 # we then consider all previously outstanding notifications
190192 # to be delivered.
191193
192 reason = {
194 reason: EmailReason = {
193195 "room_id": push_action["room_id"],
194196 "now": self.clock.time_msec(),
195197 "received_at": received_at,
274276 return may_send_at
275277
276278 async def sent_notif_update_throttle(
277 self, room_id: str, notified_push_action: dict
279 self, room_id: str, notified_push_action: EmailPushAction
278280 ) -> None:
279281 # We have sent a notification, so update the throttle accordingly.
280282 # If the event that triggered the notif happened more than
314316 self.pusher_id, room_id, self.throttle_params[room_id]
315317 )
316318
317 async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
319 async def send_notification(
320 self, push_actions: List[EmailPushAction], reason: EmailReason
321 ) -> None:
318322 logger.info("Sending notif email for user %r", self.user_id)
319323
320324 await self.mailer.send_notification_mail(
2525 from synapse.logging import opentracing
2626 from synapse.metrics.background_process_metrics import run_as_background_process
2727 from synapse.push import Pusher, PusherConfig, PusherConfigException
28 from synapse.storage.databases.main.event_push_actions import HttpPushAction
2829
2930 from . import push_rule_evaluator, push_tools
3031
272273 )
273274 break
274275
275 async def _process_one(self, push_action: dict) -> bool:
276 async def _process_one(self, push_action: HttpPushAction) -> bool:
276277 if "notify" not in push_action["actions"]:
277278 return True
278279
1313
1414 import logging
1515 import urllib.parse
16 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
16 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
1717
1818 import bleach
1919 import jinja2
2727 descriptor_from_member_events,
2828 name_from_member_event,
2929 )
30 from synapse.push.push_types import (
31 EmailReason,
32 MessageVars,
33 NotifVars,
34 RoomVars,
35 TemplateVars,
36 )
37 from synapse.storage.databases.main.event_push_actions import EmailPushAction
3038 from synapse.storage.state import StateFilter
3139 from synapse.types import StateMap, UserID
3240 from synapse.util.async_helpers import concurrently_execute
134142 % urllib.parse.urlencode(params)
135143 )
136144
137 template_vars = {"link": link}
145 template_vars: TemplateVars = {"link": link}
138146
139147 await self.send_email(
140148 email_address,
164172 % urllib.parse.urlencode(params)
165173 )
166174
167 template_vars = {"link": link}
175 template_vars: TemplateVars = {"link": link}
168176
169177 await self.send_email(
170178 email_address,
195203 % urllib.parse.urlencode(params)
196204 )
197205
198 template_vars = {"link": link}
206 template_vars: TemplateVars = {"link": link}
199207
200208 await self.send_email(
201209 email_address,
209217 app_id: str,
210218 user_id: str,
211219 email_address: str,
212 push_actions: Iterable[Dict[str, Any]],
213 reason: Dict[str, Any],
220 push_actions: Iterable[EmailPushAction],
221 reason: EmailReason,
214222 ) -> None:
215223 """
216224 Send email regarding a user's room notifications
229237 [pa["event_id"] for pa in push_actions]
230238 )
231239
232 notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
240 notifs_by_room: Dict[str, List[EmailPushAction]] = {}
233241 for pa in push_actions:
234242 notifs_by_room.setdefault(pa["room_id"], []).append(pa)
235243
257265 # actually sort our so-called rooms_in_order list, most recent room first
258266 rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
259267
260 rooms: List[Dict[str, Any]] = []
268 rooms: List[RoomVars] = []
261269
262270 for r in rooms_in_order:
263271 roomvars = await self._get_room_vars(
288296 notifs_by_room, state_by_room, notif_events, reason
289297 )
290298
291 template_vars = {
299 template_vars: TemplateVars = {
292300 "user_display_name": user_display_name,
293301 "unsubscribe_link": self._make_unsubscribe_link(
294302 user_id, app_id, email_address
301309 await self.send_email(email_address, summary_text, template_vars)
302310
303311 async def send_email(
304 self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
312 self, email_address: str, subject: str, extra_template_vars: TemplateVars
305313 ) -> None:
306314 """Send an email with the given information and template text"""
307 template_vars = {
315 template_vars: TemplateVars = {
308316 "app_name": self.app_name,
309317 "server_name": self.hs.config.server.server_name,
310318 }
326334 self,
327335 room_id: str,
328336 user_id: str,
329 notifs: Iterable[Dict[str, Any]],
337 notifs: Iterable[EmailPushAction],
330338 notif_events: Dict[str, EventBase],
331339 room_state_ids: StateMap[str],
332 ) -> Dict[str, Any]:
340 ) -> RoomVars:
333341 """
334342 Generate the variables for notifications on a per-room basis.
335343
355363
356364 room_name = await calculate_room_name(self.store, room_state_ids, user_id)
357365
358 room_vars: Dict[str, Any] = {
366 room_vars: RoomVars = {
359367 "title": room_name,
360368 "hash": string_ordinal_total(room_id), # See sender avatar hash
361369 "notifs": [],
416424
417425 async def _get_notif_vars(
418426 self,
419 notif: Dict[str, Any],
427 notif: EmailPushAction,
420428 user_id: str,
421429 notif_event: EventBase,
422430 room_state_ids: StateMap[str],
423 ) -> Dict[str, Any]:
431 ) -> NotifVars:
424432 """
425433 Generate the variables for a single notification.
426434
441449 after_limit=CONTEXT_AFTER,
442450 )
443451
444 ret = {
452 ret: NotifVars = {
445453 "link": self._make_notif_link(notif),
446454 "ts": notif["received_ts"],
447455 "messages": [],
460468 return ret
461469
462470 async def _get_message_vars(
463 self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
464 ) -> Optional[Dict[str, Any]]:
471 self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str]
472 ) -> Optional[MessageVars]:
465473 """
466474 Generate the variables for a single event, if possible.
467475
493501
494502 if sender_state_event:
495503 sender_name = name_from_member_event(sender_state_event)
496 sender_avatar_url = sender_state_event.content.get("avatar_url")
504 sender_avatar_url: Optional[str] = sender_state_event.content.get(
505 "avatar_url"
506 )
497507 else:
498508 # No state could be found, fallback to the MXID.
499509 sender_name = event.sender
503513 # sender_hash % the number of default images to choose from
504514 sender_hash = string_ordinal_total(event.sender)
505515
506 ret = {
516 ret: MessageVars = {
507517 "event_type": event.type,
508518 "is_historical": event.event_id != notif["event_id"],
509519 "id": event.event_id,
518528 return ret
519529
520530 msgtype = event.content.get("msgtype")
531 if not isinstance(msgtype, str):
532 msgtype = None
521533
522534 ret["msgtype"] = msgtype
523535
532544 return ret
533545
534546 def _add_text_message_vars(
535 self, messagevars: Dict[str, Any], event: EventBase
547 self, messagevars: MessageVars, event: EventBase
536548 ) -> None:
537549 """
538550 Potentially add a sanitised message body to the message variables.
542554 event: The event under consideration.
543555 """
544556 msgformat = event.content.get("format")
545
546 messagevars["format"] = msgformat
557 if not isinstance(msgformat, str):
558 msgformat = None
547559
548560 formatted_body = event.content.get("formatted_body")
549561 body = event.content.get("body")
554566 messagevars["body_text_html"] = safe_text(body)
555567
556568 def _add_image_message_vars(
557 self, messagevars: Dict[str, Any], event: EventBase
569 self, messagevars: MessageVars, event: EventBase
558570 ) -> None:
559571 """
560572 Potentially add an image URL to the message variables.
569581 async def _make_summary_text_single_room(
570582 self,
571583 room_id: str,
572 notifs: List[Dict[str, Any]],
584 notifs: List[EmailPushAction],
573585 room_state_ids: StateMap[str],
574586 notif_events: Dict[str, EventBase],
575587 user_id: str,
684696
685697 async def _make_summary_text(
686698 self,
687 notifs_by_room: Dict[str, List[Dict[str, Any]]],
699 notifs_by_room: Dict[str, List[EmailPushAction]],
688700 room_state_ids: Dict[str, StateMap[str]],
689701 notif_events: Dict[str, EventBase],
690 reason: Dict[str, Any],
702 reason: EmailReason,
691703 ) -> str:
692704 """
693705 Make a summary text for the email when multiple rooms have notifications.
717729 async def _make_summary_text_from_member_events(
718730 self,
719731 room_id: str,
720 notifs: List[Dict[str, Any]],
732 notifs: List[EmailPushAction],
721733 room_state_ids: StateMap[str],
722734 notif_events: Dict[str, EventBase],
723735 ) -> str:
804816 base_url = "https://matrix.to/#"
805817 return "%s/%s" % (base_url, room_id)
806818
807 def _make_notif_link(self, notif: Dict[str, str]) -> str:
819 def _make_notif_link(self, notif: EmailPushAction) -> str:
808820 """
809821 Generate a link to open an event in the web client.
810822
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 from typing import List, Optional
14
15 from typing_extensions import TypedDict
16
17
18 class EmailReason(TypedDict, total=False):
19 """
20 Information on the event that triggered the email to be sent
21
22 room_id: the ID of the room the event was sent in
23 now: timestamp in ms when the email is being sent out
24 room_name: a human-readable name for the room the event was sent in
25 received_at: the time in milliseconds at which the event was received
26 delay_before_mail_ms: the amount of time in milliseconds Synapse always waits
27 before ever emailing about a notification (to give the user a chance to respond
28 to other push notifications or notice the message in an open client window)
29 last_sent_ts: the time in milliseconds at which a notification was last sent
30 for an event in this room
31 throttle_ms: the minimum amount of time in milliseconds that must elapse
32 between two notifications sent for this room
33 """
34
35 room_id: str
36 now: int
37 room_name: Optional[str]
38 received_at: int
39 delay_before_mail_ms: int
40 last_sent_ts: int
41 throttle_ms: int
42
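`EmailReason` uses `TypedDict` with `total=False`, so each key is optional at construction but still type-checked. A minimal illustration of the pattern (using `typing.TypedDict`, available from Python 3.8; the file above imports it from `typing_extensions` for compatibility):

```python
from typing import Optional, TypedDict

class Reason(TypedDict, total=False):
    # Hypothetical cut-down analogue of EmailReason above.
    room_id: str
    room_name: Optional[str]
    received_at: int

# total=False lets the dict be built up incrementally: no key is required
# at construction time.
reason: Reason = {"room_id": "!abc:example.org"}
reason["received_at"] = 1_639_000_000_000
```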
43
44 class MessageVars(TypedDict, total=False):
45 """
46 Details about a specific message to include in a notification
47
48 event_type: the type of the event
49 is_historical: a boolean, which is `False` if the message is the one
50 that triggered the notification, `True` otherwise
51 id: the ID of the event
52 ts: the time in milliseconds at which the event was sent
53 sender_name: the display name for the event's sender
54 sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's
55 sender
56 sender_hash: a hash of the user ID of the sender
57 msgtype: the type of the message
58 body_text_html: html representation of the message
59 body_text_plain: plaintext representation of the message
60 image_url: mxc url of an image, when "msgtype" is "m.image"
61 """
62
63 event_type: str
64 is_historical: bool
65 id: str
66 ts: int
67 sender_name: str
68 sender_avatar_url: Optional[str]
69 sender_hash: int
70 msgtype: Optional[str]
71 body_text_html: str
72 body_text_plain: str
73 image_url: str
74
75
76 class NotifVars(TypedDict):
77 """
78 Details about an event we are about to include in a notification
79
80 link: a `matrix.to` link to the event
81 ts: the time in milliseconds at which the event was received
82 messages: a list of messages containing one message before the event, the
83 message in the event, and one message after the event.
84 """
85
86 link: str
87 ts: Optional[int]
88 messages: List[MessageVars]
89
90
91 class RoomVars(TypedDict):
92 """
93 Represents a room containing events to include in the email.
94
95 title: a human-readable name for the room
96 hash: a hash of the ID of the room
97 invite: a boolean, which is `True` if the room is an invite the user hasn't
98 accepted yet, `False` otherwise
99 notifs: a list of events, or an empty list if `invite` is `True`.
100 link: a `matrix.to` link to the room
101 avatar_url: the URL of the room's avatar
102 """
103
104 title: Optional[str]
105 hash: int
106 invite: bool
107 notifs: List[NotifVars]
108 link: str
109 avatar_url: Optional[str]
110
111
112 class TemplateVars(TypedDict, total=False):
113 """
114 Generic structure passed to the email sender; it can hold all the fields used in email templates.
115
116 app_name: name of the app/service this homeserver is associated with
117 server_name: name of our own homeserver
118 link: a link to include into the email to be sent
119 user_display_name: the display name for the user receiving the notification
120 unsubscribe_link: the link users can click to unsubscribe from email notifications
121 summary_text: a summary of the notification(s). The text used can be customised
122 by configuring the various settings in the `email.subjects` section of the
123 configuration file.
124 rooms: a list of rooms containing events to include in the email
125 reason: information on the event that triggered the email to be sent
126 """
127
128 app_name: str
129 server_name: str
130 link: str
131 user_display_name: str
132 unsubscribe_link: str
133 summary_text: str
134 rooms: List[RoomVars]
135 reason: EmailReason
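The new module above models email-template variables as `TypedDict`s with `total=False`. A minimal, self-contained sketch of what that buys (a cut-down `EmailReason` for illustration only; Synapse imports `TypedDict` from `typing_extensions` for older Pythons, while plain `typing` works on 3.8+):

```python
from typing import TypedDict  # Synapse uses typing_extensions for older Pythons


class EmailReason(TypedDict, total=False):
    """Cut-down illustration of the EmailReason structure above."""

    room_id: str
    now: int
    room_name: str


# total=False makes every key optional to the type checker; at runtime a
# TypedDict instance is just a plain dict, so partial construction works.
reason = EmailReason(room_id="!abc:example.org", now=1_639_000_000_000)
assert isinstance(reason, dict)
assert "room_name" not in reason
```

This lets templates receive partially-populated variable dicts while mypy still checks the key names and value types.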
8585 # We enforce that we have a `cryptography` version that bundles an `openssl`
8686 # with the latest security patches.
8787 "cryptography>=3.4.7",
88 "ijson>=3.0",
88 "ijson>=3.1",
8989 ]
9090
9191 CONDITIONAL_REQUIREMENTS = {
4545 is_guest,
4646 is_appservice_ghost,
4747 should_issue_refresh_token,
48 auth_provider_id,
49 auth_provider_session_id,
4850 ):
4951 """
5052 Args:
6264 "is_guest": is_guest,
6365 "is_appservice_ghost": is_appservice_ghost,
6466 "should_issue_refresh_token": should_issue_refresh_token,
67 "auth_provider_id": auth_provider_id,
68 "auth_provider_session_id": auth_provider_session_id,
6569 }
6670
6771 async def _handle_request(self, request, user_id):
7276 is_guest = content["is_guest"]
7377 is_appservice_ghost = content["is_appservice_ghost"]
7478 should_issue_refresh_token = content["should_issue_refresh_token"]
79 auth_provider_id = content["auth_provider_id"]
80 auth_provider_session_id = content["auth_provider_session_id"]
7581
7682 res = await self.registration_handler.register_device_inner(
7783 user_id,
8086 is_guest,
8187 is_appservice_ghost=is_appservice_ghost,
8288 should_issue_refresh_token=should_issue_refresh_token,
89 auth_provider_id=auth_provider_id,
90 auth_provider_session_id=auth_provider_session_id,
8391 )
8492
8593 return 200, res
1313 from typing import List, Optional, Tuple
1414
1515 from synapse.storage.database import LoggingDatabaseConnection
16 from synapse.storage.util.id_generators import _load_current_id
16 from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
1717
1818
19 class SlavedIdTracker:
19 class SlavedIdTracker(AbstractStreamIdTracker):
20 """Tracks the "current" stream ID of a stream with a single writer.
21
22 See `AbstractStreamIdTracker` for more details.
23
24 Note that this class does not work correctly when there are multiple
25 writers.
26 """
27
2028 def __init__(
2129 self,
2230 db_conn: LoggingDatabaseConnection,
3543 self._current = (max if self.step > 0 else min)(self._current, new_id)
3644
3745 def get_current_token(self) -> int:
38 """
39
40 Returns:
41 int
42 """
4346 return self._current
4447
4548 def get_current_token_for_writer(self, instance_name: str) -> int:
46 """Returns the position of the given writer.
47
48 For streams with single writers this is equivalent to
49 `get_current_token`.
50 """
5149 return self.get_current_token()
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
1615 from synapse.replication.tcp.streams import PushRulesStream
1716 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
1817
2423 return self._push_rules_stream_id_gen.get_current_token()
2524
2625 def process_replication_rows(self, stream_name, instance_name, token, rows):
27 # We assert this for the benefit of mypy
28 assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
29
3026 if stream_name == PushRulesStream.NAME:
3127 self._push_rules_stream_id_gen.advance(instance_name, token)
3228 for row in rows:
1313 # limitations under the License.
1414 import heapq
1515 from collections.abc import Iterable
16 from typing import TYPE_CHECKING, List, Optional, Tuple, Type
16 from typing import TYPE_CHECKING, Optional, Tuple, Type
1717
1818 import attr
1919
156156
157157 # now we fetch up to that many rows from the events table
158158
159 event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
159 event_rows = await self._store.get_all_new_forward_event_rows(
160160 instance_name, from_token, current_token, target_row_count
161161 )
162162
190190 # finally, fetch the ex-outliers rows. We assume there are few enough of these
191191 # not to bother with the limit.
192192
193 ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
193 ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
194194 instance_name, from_token, upper_limit
195195 )
196196
1616
1717 import logging
1818 import platform
19 from http import HTTPStatus
1920 from typing import TYPE_CHECKING, Optional, Tuple
2021
2122 import synapse
3738 from synapse.rest.admin.event_reports import (
3839 EventReportDetailRestServlet,
3940 EventReportsRestServlet,
41 )
42 from synapse.rest.admin.federation import (
43 DestinationsRestServlet,
44 ListDestinationsRestServlet,
4045 )
4146 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
4247 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
97102 }
98103
99104 def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
100 return 200, self.res
105 return HTTPStatus.OK, self.res
101106
102107
103108 class PurgeHistoryRestServlet(RestServlet):
129134 event = await self.store.get_event(event_id)
130135
131136 if event.room_id != room_id:
132 raise SynapseError(400, "Event is for wrong room.")
137 raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.")
133138
134139 # RoomStreamToken expects [int] not Optional[int]
135140 assert event.internal_metadata.stream_ordering is not None
143148 ts = body["purge_up_to_ts"]
144149 if not isinstance(ts, int):
145150 raise SynapseError(
146 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
151 HTTPStatus.BAD_REQUEST,
152 "purge_up_to_ts must be an int",
153 errcode=Codes.BAD_JSON,
147154 )
148155
149156 stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
159166 stream_ordering,
160167 )
161168 raise SynapseError(
162 404, "there is no event to be purged", errcode=Codes.NOT_FOUND
169 HTTPStatus.NOT_FOUND,
170 "there is no event to be purged",
171 errcode=Codes.NOT_FOUND,
163172 )
164173 (stream, topo, _event_id) = r
165174 token = "t%d-%d" % (topo, stream)
172181 )
173182 else:
174183 raise SynapseError(
175 400,
184 HTTPStatus.BAD_REQUEST,
176185 "must specify purge_up_to_event_id or purge_up_to_ts",
177186 errcode=Codes.BAD_JSON,
178187 )
181190 room_id, token, delete_local_events=delete_local_events
182191 )
183192
184 return 200, {"purge_id": purge_id}
193 return HTTPStatus.OK, {"purge_id": purge_id}
185194
186195
187196 class PurgeHistoryStatusRestServlet(RestServlet):
200209 if purge_status is None:
201210 raise NotFoundError("purge id '%s' not found" % purge_id)
202211
203 return 200, purge_status.asdict()
212 return HTTPStatus.OK, purge_status.asdict()
204213
205214
206215 ########################################################################################
255264 ListRegistrationTokensRestServlet(hs).register(http_server)
256265 NewRegistrationTokenRestServlet(hs).register(http_server)
257266 RegistrationTokenRestServlet(hs).register(http_server)
267 DestinationsRestServlet(hs).register(http_server)
268 ListDestinationsRestServlet(hs).register(http_server)
258269
259270 # Some servlets only get registered for the main process.
260271 if hs.config.worker.worker_app is None:
1212 # limitations under the License.
1313
1414 import re
15 from http import HTTPStatus
1516 from typing import Iterable, Pattern
1617
1718 from synapse.api.auth import Auth
6162 """
6263 is_admin = await auth.is_server_admin(user_id)
6364 if not is_admin:
64 raise AuthError(403, "You are not a server admin")
65 raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
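Many hunks in this release swap bare integer status codes (`200`, `400`, `403`, …) for `http.HTTPStatus` members. Since `HTTPStatus` is an `IntEnum`, the change is purely a readability improvement, not a behavioural one:

```python
from http import HTTPStatus

# HTTPStatus members are IntEnum values: they compare equal to, and format
# like, the bare integers they replace.
assert HTTPStatus.OK == 200
assert HTTPStatus.BAD_REQUEST == 400
assert HTTPStatus.FORBIDDEN == 403
assert HTTPStatus.NOT_FOUND == 404
assert "%d" % HTTPStatus.OK == "200"
```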
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313 import logging
14 from http import HTTPStatus
1415 from typing import TYPE_CHECKING, Tuple
1516
1617 from synapse.api.errors import NotFoundError, SynapseError
5253
5354 target_user = UserID.from_string(user_id)
5455 if not self.hs.is_mine(target_user):
55 raise SynapseError(400, "Can only lookup local users")
56 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
5657
5758 u = await self.store.get_user_by_id(target_user.to_string())
5859 if u is None:
6162 device = await self.device_handler.get_device(
6263 target_user.to_string(), device_id
6364 )
64 return 200, device
65 return HTTPStatus.OK, device
6566
6667 async def on_DELETE(
6768 self, request: SynapseRequest, user_id: str, device_id: str
7071
7172 target_user = UserID.from_string(user_id)
7273 if not self.hs.is_mine(target_user):
73 raise SynapseError(400, "Can only lookup local users")
74 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
7475
7576 u = await self.store.get_user_by_id(target_user.to_string())
7677 if u is None:
7778 raise NotFoundError("Unknown user")
7879
7980 await self.device_handler.delete_device(target_user.to_string(), device_id)
80 return 200, {}
81 return HTTPStatus.OK, {}
8182
8283 async def on_PUT(
8384 self, request: SynapseRequest, user_id: str, device_id: str
8687
8788 target_user = UserID.from_string(user_id)
8889 if not self.hs.is_mine(target_user):
89 raise SynapseError(400, "Can only lookup local users")
90 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
9091
9192 u = await self.store.get_user_by_id(target_user.to_string())
9293 if u is None:
9697 await self.device_handler.update_device(
9798 target_user.to_string(), device_id, body
9899 )
99 return 200, {}
100 return HTTPStatus.OK, {}
100101
101102
102103 class DevicesRestServlet(RestServlet):
123124
124125 target_user = UserID.from_string(user_id)
125126 if not self.hs.is_mine(target_user):
126 raise SynapseError(400, "Can only lookup local users")
127 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
127128
128129 u = await self.store.get_user_by_id(target_user.to_string())
129130 if u is None:
130131 raise NotFoundError("Unknown user")
131132
132133 devices = await self.device_handler.get_devices_by_user(target_user.to_string())
133 return 200, {"devices": devices, "total": len(devices)}
134 return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
134135
135136
136137 class DeleteDevicesRestServlet(RestServlet):
154155
155156 target_user = UserID.from_string(user_id)
156157 if not self.hs.is_mine(target_user):
157 raise SynapseError(400, "Can only lookup local users")
158 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
158159
159160 u = await self.store.get_user_by_id(target_user.to_string())
160161 if u is None:
166167 await self.device_handler.delete_devices(
167168 target_user.to_string(), body["devices"]
168169 )
169 return 200, {}
170 return HTTPStatus.OK, {}
1212 # limitations under the License.
1313
1414 import logging
15 from http import HTTPStatus
1516 from typing import TYPE_CHECKING, Tuple
1617
1718 from synapse.api.errors import Codes, NotFoundError, SynapseError
6566
6667 if start < 0:
6768 raise SynapseError(
68 400,
69 HTTPStatus.BAD_REQUEST,
6970 "The start parameter must be a positive integer.",
7071 errcode=Codes.INVALID_PARAM,
7172 )
7273
7374 if limit < 0:
7475 raise SynapseError(
75 400,
76 HTTPStatus.BAD_REQUEST,
7677 "The limit parameter must be a positive integer.",
7778 errcode=Codes.INVALID_PARAM,
7879 )
7980
8081 if direction not in ("f", "b"):
8182 raise SynapseError(
82 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
83 HTTPStatus.BAD_REQUEST,
84 "Unknown direction: %s" % (direction,),
85 errcode=Codes.INVALID_PARAM,
8386 )
8487
8588 event_reports, total = await self.store.get_event_reports_paginate(
8992 if (start + limit) < total:
9093 ret["next_token"] = start + len(event_reports)
9194
92 return 200, ret
95 return HTTPStatus.OK, ret
9396
9497
9598 class EventReportDetailRestServlet(RestServlet):
126129 try:
127130 resolved_report_id = int(report_id)
128131 except ValueError:
129 raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
132 raise SynapseError(
133 HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
134 )
130135
131136 if resolved_report_id < 0:
132 raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
137 raise SynapseError(
138 HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
139 )
133140
134141 ret = await self.store.get_event_report(resolved_report_id)
135142 if not ret:
136143 raise NotFoundError("Event report not found")
137144
138 return 200, ret
145 return HTTPStatus.OK, ret
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import logging
14 from http import HTTPStatus
15 from typing import TYPE_CHECKING, Tuple
16
17 from synapse.api.errors import Codes, NotFoundError, SynapseError
18 from synapse.http.servlet import RestServlet, parse_integer, parse_string
19 from synapse.http.site import SynapseRequest
20 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
21 from synapse.storage.databases.main.transactions import DestinationSortOrder
22 from synapse.types import JsonDict
23
24 if TYPE_CHECKING:
25 from synapse.server import HomeServer
26
27 logger = logging.getLogger(__name__)
28
29
30 class ListDestinationsRestServlet(RestServlet):
31 """Get request to list all destinations.
32 This requires the user to have administrator access in Synapse.
33
34 GET /_synapse/admin/v1/federation/destinations?from=0&limit=10
35
36 returns:
37 200 OK with a list of destinations on success, otherwise an error.
38
39 The parameters `from` and `limit` are used only for pagination.
40 By default, a `limit` of 100 is used.
41 The parameter `destination` can be used to filter by destination.
42 The parameter `order_by` can be used to order the result.
43 """
44
45 PATTERNS = admin_patterns("/federation/destinations$")
46
47 def __init__(self, hs: "HomeServer"):
48 self._auth = hs.get_auth()
49 self._store = hs.get_datastore()
50
51 async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
52 await assert_requester_is_admin(self._auth, request)
53
54 start = parse_integer(request, "from", default=0)
55 limit = parse_integer(request, "limit", default=100)
56
57 if start < 0:
58 raise SynapseError(
59 HTTPStatus.BAD_REQUEST,
60 "Query parameter from must be a string representing a positive integer.",
61 errcode=Codes.INVALID_PARAM,
62 )
63
64 if limit < 0:
65 raise SynapseError(
66 HTTPStatus.BAD_REQUEST,
67 "Query parameter limit must be a string representing a positive integer.",
68 errcode=Codes.INVALID_PARAM,
69 )
70
71 destination = parse_string(request, "destination")
72
73 order_by = parse_string(
74 request,
75 "order_by",
76 default=DestinationSortOrder.DESTINATION.value,
77 allowed_values=[dest.value for dest in DestinationSortOrder],
78 )
79
80 direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
81
82 destinations, total = await self._store.get_destinations_paginate(
83 start, limit, destination, order_by, direction
84 )
85 response = {"destinations": destinations, "total": total}
86 if (start + limit) < total:
87 response["next_token"] = str(start + len(destinations))
88
89 return HTTPStatus.OK, response
90
91
92 class DestinationsRestServlet(RestServlet):
93 """Get details of a destination.
94 This requires the user to have administrator access in Synapse.
95
96 GET /_synapse/admin/v1/federation/destinations/<destination>
97
98 returns:
99 200 OK with details of the destination on success, otherwise an error.
100 """
101
102 PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
103
104 def __init__(self, hs: "HomeServer"):
105 self._auth = hs.get_auth()
106 self._store = hs.get_datastore()
107
108 async def on_GET(
109 self, request: SynapseRequest, destination: str
110 ) -> Tuple[int, JsonDict]:
111 await assert_requester_is_admin(self._auth, request)
112
113 destination_retry_timings = await self._store.get_destination_retry_timings(
114 destination
115 )
116
117 if not destination_retry_timings:
118 raise NotFoundError("Unknown destination")
119
120 last_successful_stream_ordering = (
121 await self._store.get_destination_last_successful_stream_ordering(
122 destination
123 )
124 )
125
126 response = {
127 "destination": destination,
128 "failure_ts": destination_retry_timings.failure_ts,
129 "retry_last_ts": destination_retry_timings.retry_last_ts,
130 "retry_interval": destination_retry_timings.retry_interval,
131 "last_successful_stream_ordering": last_successful_stream_ordering,
132 }
133
134 return HTTPStatus.OK, response
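The `next_token` bookkeeping used by `ListDestinationsRestServlet` can be sketched in isolation. The `paginate` helper below is hypothetical (not part of Synapse); it mirrors the servlet's rule that a token is returned only when more rows remain past the current page:

```python
from typing import List, Optional, Tuple


def paginate(rows: List[str], start: int, limit: int) -> Tuple[List[str], Optional[str]]:
    """Return one page of rows plus a next_token, mirroring the servlet:
    a token is included only when rows remain beyond this page."""
    page = rows[start : start + limit]
    total = len(rows)
    next_token = str(start + len(page)) if (start + limit) < total else None
    return page, next_token


rows = ["server%d" % i for i in range(5)]
assert paginate(rows, start=0, limit=2) == (["server0", "server1"], "2")
assert paginate(rows, start=4, limit=2) == (["server4"], None)  # last page: no token
```

A client keeps requesting with `from=next_token` until the response omits `next_token`.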
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313 import logging
14 from http import HTTPStatus
1415 from typing import TYPE_CHECKING, Tuple
1516
1617 from synapse.api.errors import SynapseError
4243 await assert_user_is_admin(self.auth, requester.user)
4344
4445 if not self.is_mine_id(group_id):
45 raise SynapseError(400, "Can only delete local groups")
46 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups")
4647
4748 await self.group_server.delete_group(group_id, requester.user.to_string())
48 return 200, {}
49 return HTTPStatus.OK, {}
1313 # limitations under the License.
1414
1515 import logging
16 from http import HTTPStatus
1617 from typing import TYPE_CHECKING, Tuple
1718
1819 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
6162 room_id, requester.user.to_string()
6263 )
6364
64 return 200, {"num_quarantined": num_quarantined}
65 return HTTPStatus.OK, {"num_quarantined": num_quarantined}
6566
6667
6768 class QuarantineMediaByUser(RestServlet):
8889 user_id, requester.user.to_string()
8990 )
9091
91 return 200, {"num_quarantined": num_quarantined}
92 return HTTPStatus.OK, {"num_quarantined": num_quarantined}
9293
9394
9495 class QuarantineMediaByID(RestServlet):
117118 server_name, media_id, requester.user.to_string()
118119 )
119120
120 return 200, {}
121 return HTTPStatus.OK, {}
121122
122123
123124 class UnquarantineMediaByID(RestServlet):
146147 # Remove from quarantine this media id
147148 await self.store.quarantine_media_by_id(server_name, media_id, None)
148149
149 return 200, {}
150 return HTTPStatus.OK, {}
150151
151152
152153 class ProtectMediaByID(RestServlet):
169170 # Protect this media id
170171 await self.store.mark_local_media_as_safe(media_id, safe=True)
171172
172 return 200, {}
173 return HTTPStatus.OK, {}
173174
174175
175176 class UnprotectMediaByID(RestServlet):
192193 # Unprotect this media id
193194 await self.store.mark_local_media_as_safe(media_id, safe=False)
194195
195 return 200, {}
196 return HTTPStatus.OK, {}
196197
197198
198199 class ListMediaInRoom(RestServlet):
210211 requester = await self.auth.get_user_by_req(request)
211212 is_admin = await self.auth.is_server_admin(requester.user)
212213 if not is_admin:
213 raise AuthError(403, "You are not a server admin")
214 raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
214215
215216 local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
216217
217 return 200, {"local": local_mxcs, "remote": remote_mxcs}
218 return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs}
218219
219220
220221 class PurgeMediaCacheRestServlet(RestServlet):
232233
233234 if before_ts < 0:
234235 raise SynapseError(
235 400,
236 HTTPStatus.BAD_REQUEST,
236237 "Query parameter before_ts must be a positive integer.",
237238 errcode=Codes.INVALID_PARAM,
238239 )
239240 elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
240241 raise SynapseError(
241 400,
242 HTTPStatus.BAD_REQUEST,
242243 "Query parameter before_ts you provided is from the year 1970. "
243244 + "Double check that you are providing a timestamp in milliseconds.",
244245 errcode=Codes.INVALID_PARAM,
246247
247248 ret = await self.media_repository.delete_old_remote_media(before_ts)
248249
249 return 200, ret
250 return HTTPStatus.OK, ret
250251
251252
252253 class DeleteMediaByID(RestServlet):
266267 await assert_requester_is_admin(self.auth, request)
267268
268269 if self.server_name != server_name:
269 raise SynapseError(400, "Can only delete local media")
270 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
270271
271272 if await self.store.get_local_media(media_id) is None:
272273 raise NotFoundError("Unknown media")
276277 deleted_media, total = await self.media_repository.delete_local_media_ids(
277278 [media_id]
278279 )
279 return 200, {"deleted_media": deleted_media, "total": total}
280 return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
280281
281282
282283 class DeleteMediaByDateSize(RestServlet):
303304
304305 if before_ts < 0:
305306 raise SynapseError(
306 400,
307 HTTPStatus.BAD_REQUEST,
307308 "Query parameter before_ts must be a positive integer.",
308309 errcode=Codes.INVALID_PARAM,
309310 )
310311 elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
311312 raise SynapseError(
312 400,
313 HTTPStatus.BAD_REQUEST,
313314 "Query parameter before_ts you provided is from the year 1970. "
314315 + "Double check that you are providing a timestamp in milliseconds.",
315316 errcode=Codes.INVALID_PARAM,
316317 )
317318 if size_gt < 0:
318319 raise SynapseError(
319 400,
320 HTTPStatus.BAD_REQUEST,
320321 "Query parameter size_gt must be a string representing a positive integer.",
321322 errcode=Codes.INVALID_PARAM,
322323 )
323324
324325 if self.server_name != server_name:
325 raise SynapseError(400, "Can only delete local media")
326 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
326327
327328 logging.info(
328329 "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s"
332333 deleted_media, total = await self.media_repository.delete_old_local_media(
333334 before_ts, size_gt, keep_profiles
334335 )
335 return 200, {"deleted_media": deleted_media, "total": total}
336 return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
336337
337338
338339 class UserMediaRestServlet(RestServlet):
368369 await assert_requester_is_admin(self.auth, request)
369370
370371 if not self.is_mine(UserID.from_string(user_id)):
371 raise SynapseError(400, "Can only look up local users")
372 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
372373
373374 user = await self.store.get_user_by_id(user_id)
374375 if user is None:
379380
380381 if start < 0:
381382 raise SynapseError(
382 400,
383 HTTPStatus.BAD_REQUEST,
383384 "Query parameter from must be a string representing a positive integer.",
384385 errcode=Codes.INVALID_PARAM,
385386 )
386387
387388 if limit < 0:
388389 raise SynapseError(
389 400,
390 HTTPStatus.BAD_REQUEST,
390391 "Query parameter limit must be a string representing a positive integer.",
391392 errcode=Codes.INVALID_PARAM,
392393 )
424425 if (start + limit) < total:
425426 ret["next_token"] = start + len(media)
426427
427 return 200, ret
428 return HTTPStatus.OK, ret
428429
429430 async def on_DELETE(
430431 self, request: SynapseRequest, user_id: str
435436 await assert_requester_is_admin(self.auth, request)
436437
437438 if not self.is_mine(UserID.from_string(user_id)):
438 raise SynapseError(400, "Can only look up local users")
439 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
439440
440441 user = await self.store.get_user_by_id(user_id)
441442 if user is None:
446447
447448 if start < 0:
448449 raise SynapseError(
449 400,
450 HTTPStatus.BAD_REQUEST,
450451 "Query parameter from must be a string representing a positive integer.",
451452 errcode=Codes.INVALID_PARAM,
452453 )
453454
454455 if limit < 0:
455456 raise SynapseError(
456 400,
457 HTTPStatus.BAD_REQUEST,
457458 "Query parameter limit must be a string representing a positive integer.",
458459 errcode=Codes.INVALID_PARAM,
459460 )
491492 ([row["media_id"] for row in media])
492493 )
493494
494 return 200, {"deleted_media": deleted_media, "total": total}
495 return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
495496
496497
497498 def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None:
1313
1414 import logging
1515 import string
16 from http import HTTPStatus
1617 from typing import TYPE_CHECKING, Tuple
1718
1819 from synapse.api.errors import Codes, NotFoundError, SynapseError
7677 await assert_requester_is_admin(self.auth, request)
7778 valid = parse_boolean(request, "valid")
7879 token_list = await self.store.get_registration_tokens(valid)
79 return 200, {"registration_tokens": token_list}
80 return HTTPStatus.OK, {"registration_tokens": token_list}
8081
8182
8283 class NewRegistrationTokenRestServlet(RestServlet):
122123 if "token" in body:
123124 token = body["token"]
124125 if not isinstance(token, str):
125 raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
126 raise SynapseError(
127 HTTPStatus.BAD_REQUEST,
128 "token must be a string",
129 Codes.INVALID_PARAM,
130 )
126131 if not (0 < len(token) <= 64):
127132 raise SynapseError(
128 400,
133 HTTPStatus.BAD_REQUEST,
129134 "token must not be empty and must not be longer than 64 characters",
130135 Codes.INVALID_PARAM,
131136 )
132137 if not set(token).issubset(self.allowed_chars_set):
133138 raise SynapseError(
134 400,
139 HTTPStatus.BAD_REQUEST,
135140 "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
136141 Codes.INVALID_PARAM,
137142 )
141146 length = body.get("length", 16)
142147 if not isinstance(length, int):
143148 raise SynapseError(
144 400, "length must be an integer", Codes.INVALID_PARAM
149 HTTPStatus.BAD_REQUEST,
150 "length must be an integer",
151 Codes.INVALID_PARAM,
145152 )
146153 if not (0 < length <= 64):
147154 raise SynapseError(
148 400,
155 HTTPStatus.BAD_REQUEST,
149156 "length must be greater than zero and not greater than 64",
150157 Codes.INVALID_PARAM,
151158 )
161168 or (isinstance(uses_allowed, int) and uses_allowed >= 0)
162169 ):
163170 raise SynapseError(
164 400,
171 HTTPStatus.BAD_REQUEST,
165172 "uses_allowed must be a non-negative integer or null",
166173 Codes.INVALID_PARAM,
167174 )
169176 expiry_time = body.get("expiry_time", None)
170177 if not isinstance(expiry_time, (int, type(None))):
171178 raise SynapseError(
172 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
179 HTTPStatus.BAD_REQUEST,
180 "expiry_time must be an integer or null",
181 Codes.INVALID_PARAM,
173182 )
174183 if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
175184 raise SynapseError(
176 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
185 HTTPStatus.BAD_REQUEST,
186 "expiry_time must not be in the past",
187 Codes.INVALID_PARAM,
177188 )
178189
179190 created = await self.store.create_registration_token(
181192 )
182193 if not created:
183194 raise SynapseError(
184 400, f"Token already exists: {token}", Codes.INVALID_PARAM
195 HTTPStatus.BAD_REQUEST,
196 f"Token already exists: {token}",
197 Codes.INVALID_PARAM,
185198 )
186199
187200 resp = {
191204 "completed": 0,
192205 "expiry_time": expiry_time,
193206 }
194 return 200, resp
207 return HTTPStatus.OK, resp
195208
196209
197210 class RegistrationTokenRestServlet(RestServlet):
260273 if token_info is None:
261274 raise NotFoundError(f"No such registration token: {token}")
262275
263 return 200, token_info
276 return HTTPStatus.OK, token_info
264277
265278 async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
266279 """Update a registration token."""
276289 or (isinstance(uses_allowed, int) and uses_allowed >= 0)
277290 ):
278291 raise SynapseError(
279 400,
292 HTTPStatus.BAD_REQUEST,
280293 "uses_allowed must be a non-negative integer or null",
281294 Codes.INVALID_PARAM,
282295 )
286299 expiry_time = body["expiry_time"]
287300 if not isinstance(expiry_time, (int, type(None))):
288301 raise SynapseError(
289 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
302 HTTPStatus.BAD_REQUEST,
303 "expiry_time must be an integer or null",
304 Codes.INVALID_PARAM,
290305 )
291306 if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
292307 raise SynapseError(
293 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
308 HTTPStatus.BAD_REQUEST,
309 "expiry_time must not be in the past",
310 Codes.INVALID_PARAM,
294311 )
295312 new_attributes["expiry_time"] = expiry_time
296313
306323 if token_info is None:
307324 raise NotFoundError(f"No such registration token: {token}")
308325
309 return 200, token_info
326 return HTTPStatus.OK, token_info
310327
311328 async def on_DELETE(
312329 self, request: SynapseRequest, token: str
315332 await assert_requester_is_admin(self.auth, request)
316333
317334 if await self.store.delete_registration_token(token):
318 return 200, {}
335 return HTTPStatus.OK, {}
319336
320337 raise NotFoundError(f"No such registration token: {token}")
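The token validation in `NewRegistrationTokenRestServlet` boils down to three checks: type, length, and character set. A standalone sketch (the `validate_token` helper is hypothetical, not Synapse's API):

```python
import string

# Characters matched by the regex [A-Za-z0-9-_], as in the servlet above.
ALLOWED_CHARS = set(string.ascii_letters + string.digits + "-_")


def validate_token(token: object) -> None:
    """Raise ValueError unless `token` passes the same three checks as
    NewRegistrationTokenRestServlet: type, length, and character set."""
    if not isinstance(token, str):
        raise ValueError("token must be a string")
    if not (0 < len(token) <= 64):
        raise ValueError("token must not be empty and must not be longer than 64 characters")
    if not set(token).issubset(ALLOWED_CHARS):
        raise ValueError("token must consist only of characters matched by the regex [A-Za-z0-9-_]")


validate_token("my-token_01")  # valid: no exception
```

Using `set(token).issubset(...)` avoids compiling a regex for what is a simple membership test.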
101101 )
102102
103103 if not RoomID.is_valid(room_id):
104 raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
105
106 if not await self._store.get_room(room_id):
107 raise NotFoundError("Unknown room id %s" % (room_id,))
104 raise SynapseError(
105 HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
106 )
108107
109108 delete_id = self._pagination_handler.start_shutdown_and_purge_room(
110109 room_id=room_id,
117116 force_purge=force_purge,
118117 )
119118
120 return 200, {"delete_id": delete_id}
119 return HTTPStatus.OK, {"delete_id": delete_id}
121120
122121
123122 class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
136135 await assert_requester_is_admin(self._auth, request)
137136
138137 if not RoomID.is_valid(room_id):
139 raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
138 raise SynapseError(
139 HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
140 )
140141
141142 delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id)
142143 if delete_ids is None:
152153 **delete.asdict(),
153154 }
154155 ]
155 return 200, {"results": cast(JsonDict, response)}
156 return HTTPStatus.OK, {"results": cast(JsonDict, response)}
156157
157158
158159 class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
174175 if delete_status is None:
175176 raise NotFoundError("delete id '%s' not found" % delete_id)
176177
177 return 200, cast(JsonDict, delete_status.asdict())
178 return HTTPStatus.OK, cast(JsonDict, delete_status.asdict())
178179
179180
180181 class ListRoomRestServlet(RestServlet):
216217 RoomSortOrder.STATE_EVENTS.value,
217218 ):
218219 raise SynapseError(
219 400,
220 HTTPStatus.BAD_REQUEST,
220221 "Unknown value for order_by: %s" % (order_by,),
221222 errcode=Codes.INVALID_PARAM,
222223 )
224225 search_term = parse_string(request, "search_term", encoding="utf-8")
225226 if search_term == "":
226227 raise SynapseError(
227 400,
228 HTTPStatus.BAD_REQUEST,
228229 "search_term cannot be an empty string",
229230 errcode=Codes.INVALID_PARAM,
230231 )
232233 direction = parse_string(request, "dir", default="f")
233234 if direction not in ("f", "b"):
234235 raise SynapseError(
235 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
236 HTTPStatus.BAD_REQUEST,
237 "Unknown direction: %s" % (direction,),
238 errcode=Codes.INVALID_PARAM,
236239 )
237240
238241 reverse_order = True if direction == "b" else False
264267 else:
265268 response["prev_batch"] = 0
266269
267 return 200, response
270 return HTTPStatus.OK, response
268271
269272
270273 class RoomRestServlet(RestServlet):
309312 members = await self.store.get_users_in_room(room_id)
310313 ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
311314
312 return 200, ret
315 return HTTPStatus.OK, ret
313316
314317 async def on_DELETE(
315318 self, request: SynapseRequest, room_id: str
385388 # See https://github.com/python/mypy/issues/4976#issuecomment-579883622
386389 # for some discussion on why this is necessary. Either way,
387390 # `ret` is an opaque dictionary blob as far as the rest of the app cares.
388 return 200, cast(JsonDict, ret)
391 return HTTPStatus.OK, cast(JsonDict, ret)
389392
390393
391394 class RoomMembersRestServlet(RestServlet):
412415 members = await self.store.get_users_in_room(room_id)
413416 ret = {"members": members, "total": len(members)}
414417
415 return 200, ret
418 return HTTPStatus.OK, ret
416419
417420
418421 class RoomStateRestServlet(RestServlet):
442445 event_ids = await self.store.get_current_state_ids(room_id)
443446 events = await self.store.get_events(event_ids.values())
444447 now = self.clock.time_msec()
445 room_state = await self._event_serializer.serialize_events(
446 events.values(),
447 now,
448 # We don't bother bundling aggregations in when asked for state
449 # events, as clients won't use them.
450 bundle_relations=False,
451 )
448 room_state = await self._event_serializer.serialize_events(events.values(), now)
452449 ret = {"state": room_state}
453450
454 return 200, ret
451 return HTTPStatus.OK, ret
455452
456453
457454 class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
480477 target_user = UserID.from_string(content["user_id"])
481478
482479 if not self.hs.is_mine(target_user):
483 raise SynapseError(400, "This endpoint can only be used with local users")
480 raise SynapseError(
481 HTTPStatus.BAD_REQUEST,
482 "This endpoint can only be used with local users",
483 )
484484
485485 if not await self.admin_handler.get_user(target_user):
486486 raise NotFoundError("User not found")
526526 ratelimit=False,
527527 )
528528
529 return 200, {"room_id": room_id}
529 return HTTPStatus.OK, {"room_id": room_id}
530530
531531
532532 class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
567567 # Figure out which local users currently have power in the room, if any.
568568 room_state = await self.state_handler.get_current_state(room_id)
569569 if not room_state:
570 raise SynapseError(400, "Server not in room")
570 raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
571571
572572 create_event = room_state[(EventTypes.Create, "")]
573573 power_levels = room_state.get((EventTypes.PowerLevels, ""))
581581 admin_users.sort(key=lambda user: user_power[user])
582582
583583 if not admin_users:
584 raise SynapseError(400, "No local admin user in room")
584 raise SynapseError(
585 HTTPStatus.BAD_REQUEST, "No local admin user in room"
586 )
585587
586588 admin_user_id = None
587589
598600
599601 if not admin_user_id:
600602 raise SynapseError(
601 400,
603 HTTPStatus.BAD_REQUEST,
602604 "No local admin user in room",
603605 )
604606
609611 admin_user_id = create_event.sender
610612 if not self.is_mine_id(admin_user_id):
611613 raise SynapseError(
612 400,
614 HTTPStatus.BAD_REQUEST,
613615 "No local admin user in room",
614616 )
615617
638640 except AuthError:
639641 # The admin user we found turned out not to have enough power.
640642 raise SynapseError(
641 400, "No local admin user in room with power to update power levels."
643 HTTPStatus.BAD_REQUEST,
644 "No local admin user in room with power to update power levels.",
642645 )
643646
644647 # Now we check if the user we're granting admin rights to is already in
652655 )
653656
654657 if is_joined:
655 return 200, {}
658 return HTTPStatus.OK, {}
656659
657660 join_rules = room_state.get((EventTypes.JoinRules, ""))
658661 is_public = False
660663 is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
661664
662665 if is_public:
663 return 200, {}
666 return HTTPStatus.OK, {}
664667
665668 await self.room_member_handler.update_membership(
666669 fake_requester,
669672 action=Membership.INVITE,
670673 )
671674
672 return 200, {}
675 return HTTPStatus.OK, {}
673676
674677
675678 class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
701704 room_id, _ = await self.resolve_room_id(room_identifier)
702705
703706 deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
704 return 200, {"deleted": deleted_count}
707 return HTTPStatus.OK, {"deleted": deleted_count}
705708
706709 async def on_GET(
707710 self, request: SynapseRequest, room_identifier: str
712715 room_id, _ = await self.resolve_room_id(room_identifier)
713716
714717 extremities = await self.store.get_forward_extremities_for_room(room_id)
715 return 200, {"count": len(extremities), "results": extremities}
718 return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
716719
717720
718721 class RoomEventContextServlet(RestServlet):
761764 )
762765
763766 if not results:
764 raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
767 raise SynapseError(
768 HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
769 )
765770
766771 time_now = self.clock.time_msec()
767772 results["events_before"] = await self._event_serializer.serialize_events(
774779 results["events_after"], time_now
775780 )
776781 results["state"] = await self._event_serializer.serialize_events(
777 results["state"],
778 time_now,
779 # No need to bundle aggregations for state events
780 bundle_relations=False,
781 )
782
783 return 200, results
782 results["state"], time_now
783 )
784
785 return HTTPStatus.OK, results
784786
785787
786788 class BlockRoomRestServlet(RestServlet):
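The hunks above consistently raise errors as a (status code, message, errcode) triple. `SketchError` below is a hypothetical stand-in for `synapse.api.errors.SynapseError` (not importable here), sketching only that pattern:

```python
from http import HTTPStatus

# Hypothetical stand-in for synapse.api.errors.SynapseError, illustrating the
# (HTTP status, human-readable message, Matrix errcode) triple raised above.
class SketchError(Exception):
    def __init__(self, code: int, msg: str, errcode: str = "M_UNKNOWN"):
        super().__init__(msg)
        self.code = int(code)       # HTTP status for the response
        self.errcode = errcode      # machine-readable Matrix error code

try:
    raise SketchError(
        HTTPStatus.BAD_REQUEST, "!foo is not a legal room ID", "M_INVALID_PARAM"
    )
except SketchError as exc:
    caught = exc
```

The real `SynapseError` is serialised into the JSON error body that Matrix clients consume; the names here are illustrative only.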
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13 from http import HTTPStatus
1314 from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
1415
1516 from synapse.api.constants import EventTypes
8182 # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
8283 # admin api).
8384 if not self.server_notices_manager.is_enabled():
84 raise SynapseError(400, "Server notices are not enabled on this server")
85 raise SynapseError(
86 HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server"
87 )
8588
8689 target_user = UserID.from_string(body["user_id"])
8790 if not self.hs.is_mine(target_user):
88 raise SynapseError(400, "Server notices can only be sent to local users")
91 raise SynapseError(
92 HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
93 )
8994
9095 if not await self.admin_handler.get_user(target_user):
9196 raise NotFoundError("User not found")
98103 txn_id=txn_id,
99104 )
100105
101 return 200, {"event_id": event.event_id}
106 return HTTPStatus.OK, {"event_id": event.event_id}
102107
103108 def on_PUT(
104109 self, request: SynapseRequest, txn_id: str
1212 # limitations under the License.
1313
1414 import logging
15 from http import HTTPStatus
1516 from typing import TYPE_CHECKING, Tuple
1617
1718 from synapse.api.errors import Codes, SynapseError
5253 UserSortOrder.DISPLAYNAME.value,
5354 ):
5455 raise SynapseError(
55 400,
56 HTTPStatus.BAD_REQUEST,
5657 "Unknown value for order_by: %s" % (order_by,),
5758 errcode=Codes.INVALID_PARAM,
5859 )
6061 start = parse_integer(request, "from", default=0)
6162 if start < 0:
6263 raise SynapseError(
63 400,
64 HTTPStatus.BAD_REQUEST,
6465 "Query parameter from must be a string representing a positive integer.",
6566 errcode=Codes.INVALID_PARAM,
6667 )
6869 limit = parse_integer(request, "limit", default=100)
6970 if limit < 0:
7071 raise SynapseError(
71 400,
72 HTTPStatus.BAD_REQUEST,
7273 "Query parameter limit must be a string representing a positive integer.",
7374 errcode=Codes.INVALID_PARAM,
7475 )
7677 from_ts = parse_integer(request, "from_ts", default=0)
7778 if from_ts < 0:
7879 raise SynapseError(
79 400,
80 HTTPStatus.BAD_REQUEST,
8081 "Query parameter from_ts must be a string representing a positive integer.",
8182 errcode=Codes.INVALID_PARAM,
8283 )
8586 if until_ts is not None:
8687 if until_ts < 0:
8788 raise SynapseError(
88 400,
89 HTTPStatus.BAD_REQUEST,
8990 "Query parameter until_ts must be a string representing a positive integer.",
9091 errcode=Codes.INVALID_PARAM,
9192 )
9293 if until_ts <= from_ts:
9394 raise SynapseError(
94 400,
95 HTTPStatus.BAD_REQUEST,
9596 "Query parameter until_ts must be greater than from_ts.",
9697 errcode=Codes.INVALID_PARAM,
9798 )
99100 search_term = parse_string(request, "search_term")
100101 if search_term == "":
101102 raise SynapseError(
102 400,
103 HTTPStatus.BAD_REQUEST,
103104 "Query parameter search_term cannot be an empty string.",
104105 errcode=Codes.INVALID_PARAM,
105106 )
107108 direction = parse_string(request, "dir", default="f")
108109 if direction not in ("f", "b"):
109110 raise SynapseError(
110 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
111 HTTPStatus.BAD_REQUEST,
112 "Unknown direction: %s" % (direction,),
113 errcode=Codes.INVALID_PARAM,
111114 )
112115
113116 users_media, total = await self.store.get_users_media_usage_paginate(
117120 if (start + limit) < total:
118121 ret["next_token"] = start + len(users_media)
119122
120 return 200, ret
123 return HTTPStatus.OK, ret
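The statistics hunk above validates `from`, `limit`, `from_ts` and `until_ts` before querying. `validate_window` below is a hypothetical condensation of those checks, not part of Synapse:

```python
from http import HTTPStatus
from typing import Optional, Tuple

# Hypothetical condensation of the query-parameter checks above: "from",
# "limit" and "from_ts" must be non-negative, and until_ts (when given) must
# be non-negative and strictly greater than from_ts.
def validate_window(
    start: int, limit: int, from_ts: int, until_ts: Optional[int] = None
) -> Tuple[int, Optional[str]]:
    for name, value in (("from", start), ("limit", limit), ("from_ts", from_ts)):
        if value < 0:
            return HTTPStatus.BAD_REQUEST, f"Query parameter {name} must be non-negative."
    if until_ts is not None:
        if until_ts < 0:
            return HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be non-negative."
        if until_ts <= from_ts:
            return HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be greater than from_ts."
    return HTTPStatus.OK, None

assert validate_window(0, 100, 0) == (HTTPStatus.OK, None)
assert validate_window(0, 100, 10, until_ts=5)[0] == HTTPStatus.BAD_REQUEST
```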
7878
7979 if start < 0:
8080 raise SynapseError(
81 400,
81 HTTPStatus.BAD_REQUEST,
8282 "Query parameter from must be a string representing a positive integer.",
8383 errcode=Codes.INVALID_PARAM,
8484 )
8585
8686 if limit < 0:
8787 raise SynapseError(
88 400,
88 HTTPStatus.BAD_REQUEST,
8989 "Query parameter limit must be a string representing a positive integer.",
9090 errcode=Codes.INVALID_PARAM,
9191 )
121121 if (start + limit) < total:
122122 ret["next_token"] = str(start + len(users))
123123
124 return 200, ret
124 return HTTPStatus.OK, ret
125125
126126
127127 class UserRestServletV2(RestServlet):
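The pagination hunks above only attach a `next_token` when more results remain past the current page. A hedged, self-contained sketch of that convention (the helper name is hypothetical):

```python
# Hypothetical sketch of the "next_token" convention in the hunks above:
# include next_token only when (start + limit) is still short of the total.
def paginate(items, start, limit):
    page = items[start:start + limit]
    ret = {"results": page, "total": len(items)}
    if (start + limit) < len(items):
        ret["next_token"] = start + len(page)
    return ret

first = paginate(list(range(10)), 0, 4)    # more pages remain
last = paginate(list(range(10)), 8, 4)     # final page, no token
```

Clients feed `next_token` back as the next request's `from` parameter, so the absence of the key is the end-of-results signal.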
171171
172172 target_user = UserID.from_string(user_id)
173173 if not self.hs.is_mine(target_user):
174 raise SynapseError(400, "Can only look up local users")
174 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
175175
176176 ret = await self.admin_handler.get_user(target_user)
177177
178178 if not ret:
179179 raise NotFoundError("User not found")
180180
181 return 200, ret
181 return HTTPStatus.OK, ret
182182
183183 async def on_PUT(
184184 self, request: SynapseRequest, user_id: str
190190 body = parse_json_object_from_request(request)
191191
192192 if not self.hs.is_mine(target_user):
193 raise SynapseError(400, "This endpoint can only be used with local users")
193 raise SynapseError(
194 HTTPStatus.BAD_REQUEST,
195 "This endpoint can only be used with local users",
196 )
194197
195198 user = await self.admin_handler.get_user(target_user)
196199 user_id = target_user.to_string()
209212
210213 user_type = body.get("user_type", None)
211214 if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
212 raise SynapseError(400, "Invalid user type")
215 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
213216
214217 set_admin_to = body.get("admin", False)
215218 if not isinstance(set_admin_to, bool):
222225 password = body.get("password", None)
223226 if password is not None:
224227 if not isinstance(password, str) or len(password) > 512:
225 raise SynapseError(400, "Invalid password")
228 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
226229
227230 deactivate = body.get("deactivated", False)
228231 if not isinstance(deactivate, bool):
229 raise SynapseError(400, "'deactivated' parameter is not of type boolean")
232 raise SynapseError(
233 HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
234 )
230235
231236 # convert List[Dict[str, str]] into List[Tuple[str, str]]
232237 if external_ids is not None:
281286 user_id,
282287 )
283288 except ExternalIDReuseException:
284 raise SynapseError(409, "External id is already in use.")
289 raise SynapseError(
290 HTTPStatus.CONFLICT, "External id is already in use."
291 )
285292
286293 if "avatar_url" in body and isinstance(body["avatar_url"], str):
287294 await self.profile_handler.set_avatar_url(
292299 if set_admin_to != user["admin"]:
293300 auth_user = requester.user
294301 if target_user == auth_user and not set_admin_to:
295 raise SynapseError(400, "You may not demote yourself.")
302 raise SynapseError(
303 HTTPStatus.BAD_REQUEST, "You may not demote yourself."
304 )
296305
297306 await self.store.set_server_admin(target_user, set_admin_to)
298307
318327 and self.auth_handler.can_change_password()
319328 ):
320329 raise SynapseError(
321 400, "Must provide a password to re-activate an account."
330 HTTPStatus.BAD_REQUEST,
331 "Must provide a password to re-activate an account.",
322332 )
323333
324334 await self.deactivate_account_handler.activate_account(
331341 user = await self.admin_handler.get_user(target_user)
332342 assert user is not None
333343
334 return 200, user
344 return HTTPStatus.OK, user
335345
336346 else: # create user
337347 displayname = body.get("displayname", None)
380390 user_id,
381391 )
382392 except ExternalIDReuseException:
383 raise SynapseError(409, "External id is already in use.")
393 raise SynapseError(
394 HTTPStatus.CONFLICT, "External id is already in use."
395 )
384396
385397 if "avatar_url" in body and isinstance(body["avatar_url"], str):
386398 await self.profile_handler.set_avatar_url(
428440
429441 nonce = secrets.token_hex(64)
430442 self.nonces[nonce] = int(self.reactor.seconds())
431 return 200, {"nonce": nonce}
443 return HTTPStatus.OK, {"nonce": nonce}
432444
433445 async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
434446 self._clear_old_nonces()
435447
436448 if not self.hs.config.registration.registration_shared_secret:
437 raise SynapseError(400, "Shared secret registration is not enabled")
449 raise SynapseError(
450 HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled"
451 )
438452
439453 body = parse_json_object_from_request(request)
440454
441455 if "nonce" not in body:
442 raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
456 raise SynapseError(
457 HTTPStatus.BAD_REQUEST,
458 "nonce must be specified",
459 errcode=Codes.BAD_JSON,
460 )
443461
444462 nonce = body["nonce"]
445463
446464 if nonce not in self.nonces:
447 raise SynapseError(400, "unrecognised nonce")
465 raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce")
448466
449467 # Delete the nonce, so it can't be reused, even if it's invalid
450468 del self.nonces[nonce]
451469
452470 if "username" not in body:
453471 raise SynapseError(
454 400, "username must be specified", errcode=Codes.BAD_JSON
472 HTTPStatus.BAD_REQUEST,
473 "username must be specified",
474 errcode=Codes.BAD_JSON,
455475 )
456476 else:
457477 if not isinstance(body["username"], str) or len(body["username"]) > 512:
458 raise SynapseError(400, "Invalid username")
478 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
459479
460480 username = body["username"].encode("utf-8")
461481 if b"\x00" in username:
462 raise SynapseError(400, "Invalid username")
482 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
463483
464484 if "password" not in body:
465485 raise SynapseError(
466 400, "password must be specified", errcode=Codes.BAD_JSON
486 HTTPStatus.BAD_REQUEST,
487 "password must be specified",
488 errcode=Codes.BAD_JSON,
467489 )
468490 else:
469491 password = body["password"]
470492 if not isinstance(password, str) or len(password) > 512:
471 raise SynapseError(400, "Invalid password")
493 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
472494
473495 password_bytes = password.encode("utf-8")
474496 if b"\x00" in password_bytes:
475 raise SynapseError(400, "Invalid password")
497 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
476498
477499 password_hash = await self.auth_handler.hash(password)
478500
481503 displayname = body.get("displayname", None)
482504
483505 if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
484 raise SynapseError(400, "Invalid user type")
506 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
485507
486508 if "mac" not in body:
487 raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
509 raise SynapseError(
510 HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON
511 )
488512
489513 got_mac = body["mac"]
490514
506530 want_mac = want_mac_builder.hexdigest()
507531
508532 if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
509 raise SynapseError(403, "HMAC incorrect")
533 raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")
510534
511535 # Reuse the parts of RegisterRestServlet to reduce code duplication
512536 from synapse.rest.client.register import RegisterRestServlet
523547 )
524548
525549 result = await register._create_registration_details(user_id, body)
526 return 200, result
550 return HTTPStatus.OK, result
527551
528552
529553 class WhoisRestServlet(RestServlet):
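The shared-secret registration hunk above checks the client's MAC with `hmac.compare_digest`. A self-contained sketch of that comparison (the key and message below are placeholders, not Synapse's actual MAC layout):

```python
import hashlib
import hmac

# Placeholder key and message, purely for illustration.
secret = b"example-shared-secret"
message = b"nonce\x00alice\x00wonderland\x00notadmin"

want_mac = hmac.new(secret, message, hashlib.sha1).hexdigest()
got_mac = hmac.new(secret, message, hashlib.sha1).hexdigest()

# compare_digest runs in time independent of where the inputs first differ,
# which prevents timing attacks on the MAC check.
matched = hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii"))
```

An ordinary `==` comparison can short-circuit on the first differing byte, leaking information through response timing; `compare_digest` is the standard defence.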
551575 await assert_user_is_admin(self.auth, auth_user)
552576
553577 if not self.hs.is_mine(target_user):
554 raise SynapseError(400, "Can only whois a local user")
578 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
555579
556580 ret = await self.admin_handler.get_whois(target_user)
557581
558 return 200, ret
582 return HTTPStatus.OK, ret
559583
560584
561585 class DeactivateAccountRestServlet(RestServlet):
574598 await assert_user_is_admin(self.auth, requester.user)
575599
576600 if not self.is_mine(UserID.from_string(target_user_id)):
577 raise SynapseError(400, "Can only deactivate local users")
601 raise SynapseError(
602 HTTPStatus.BAD_REQUEST, "Can only deactivate local users"
603 )
578604
579605 if not await self.store.get_user_by_id(target_user_id):
580606 raise NotFoundError("User not found")
596622 else:
597623 id_server_unbind_result = "no-support"
598624
599 return 200, {"id_server_unbind_result": id_server_unbind_result}
625 return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result}
600626
601627
602628 class AccountValidityRenewServlet(RestServlet):
619645
620646 if "user_id" not in body:
621647 raise SynapseError(
622 400,
648 HTTPStatus.BAD_REQUEST,
623649 "Missing property 'user_id' in the request body",
624650 )
625651
630656 )
631657
632658 res = {"expiration_ts": expiration_ts}
633 return 200, res
659 return HTTPStatus.OK, res
634660
635661
636662 class ResetPasswordRestServlet(RestServlet):
677703 await self._set_password_handler.set_password(
678704 target_user_id, new_password_hash, logout_devices, requester
679705 )
680 return 200, {}
706 return HTTPStatus.OK, {}
681707
682708
683709 class SearchUsersRestServlet(RestServlet):
711737
712738 # To allow all users to get the users list
713739 # if not is_admin and target_user != auth_user:
714 # raise AuthError(403, "You are not a server admin")
740 # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
715741
716742 if not self.hs.is_mine(target_user):
717 raise SynapseError(400, "Can only users a local user")
743 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
718744
719745 term = parse_string(request, "term", required=True)
720746 logger.info("term: %s ", term)
721747
722748 ret = await self.store.search_users(term)
723 return 200, ret
749 return HTTPStatus.OK, ret
724750
725751
726752 class UserAdminServlet(RestServlet):
764790 target_user = UserID.from_string(user_id)
765791
766792 if not self.hs.is_mine(target_user):
767 raise SynapseError(400, "Only local users can be admins of this homeserver")
793 raise SynapseError(
794 HTTPStatus.BAD_REQUEST,
795 "Only local users can be admins of this homeserver",
796 )
768797
769798 is_admin = await self.store.is_server_admin(target_user)
770799
771 return 200, {"admin": is_admin}
800 return HTTPStatus.OK, {"admin": is_admin}
772801
773802 async def on_PUT(
774803 self, request: SynapseRequest, user_id: str
784813 assert_params_in_dict(body, ["admin"])
785814
786815 if not self.hs.is_mine(target_user):
787 raise SynapseError(400, "Only local users can be admins of this homeserver")
816 raise SynapseError(
817 HTTPStatus.BAD_REQUEST,
818 "Only local users can be admins of this homeserver",
819 )
788820
789821 set_admin_to = bool(body["admin"])
790822
791823 if target_user == auth_user and not set_admin_to:
792 raise SynapseError(400, "You may not demote yourself.")
824 raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.")
793825
794826 await self.store.set_server_admin(target_user, set_admin_to)
795827
796 return 200, {}
828 return HTTPStatus.OK, {}
797829
798830
799831 class UserMembershipRestServlet(RestServlet):
815847
816848 room_ids = await self.store.get_rooms_for_user(user_id)
817849 ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
818 return 200, ret
850 return HTTPStatus.OK, ret
819851
820852
821853 class PushersRestServlet(RestServlet):
844876 await assert_requester_is_admin(self.auth, request)
845877
846878 if not self.is_mine(UserID.from_string(user_id)):
847 raise SynapseError(400, "Can only look up local users")
879 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
848880
849881 if not await self.store.get_user_by_id(user_id):
850882 raise NotFoundError("User not found")
853885
854886 filtered_pushers = [p.as_dict() for p in pushers]
855887
856 return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
888 return HTTPStatus.OK, {
889 "pushers": filtered_pushers,
890 "total": len(filtered_pushers),
891 }
857892
858893
859894 class UserTokenRestServlet(RestServlet):
886921 auth_user = requester.user
887922
888923 if not self.hs.is_mine_id(user_id):
889 raise SynapseError(400, "Only local users can be logged in as")
924 raise SynapseError(
925 HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
926 )
890927
891928 body = parse_json_object_from_request(request, allow_empty_body=True)
892929
893930 valid_until_ms = body.get("valid_until_ms")
894931 if valid_until_ms and not isinstance(valid_until_ms, int):
895 raise SynapseError(400, "'valid_until_ms' parameter must be an int")
932 raise SynapseError(
933 HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
934 )
896935
897936 if auth_user.to_string() == user_id:
898 raise SynapseError(400, "Cannot use admin API to login as self")
937 raise SynapseError(
938 HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self"
939 )
899940
900941 token = await self.auth_handler.create_access_token_for_user_id(
901942 user_id=auth_user.to_string(),
904945 puppets_user_id=user_id,
905946 )
906947
907 return 200, {"access_token": token}
948 return HTTPStatus.OK, {"access_token": token}
908949
909950
910951 class ShadowBanRestServlet(RestServlet):
946987 await assert_requester_is_admin(self.auth, request)
947988
948989 if not self.hs.is_mine_id(user_id):
949 raise SynapseError(400, "Only local users can be shadow-banned")
990 raise SynapseError(
991 HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
992 )
950993
951994 await self.store.set_shadow_banned(UserID.from_string(user_id), True)
952995
953 return 200, {}
996 return HTTPStatus.OK, {}
954997
955998 async def on_DELETE(
956999 self, request: SynapseRequest, user_id: str
9581001 await assert_requester_is_admin(self.auth, request)
9591002
9601003 if not self.hs.is_mine_id(user_id):
961 raise SynapseError(400, "Only local users can be shadow-banned")
1004 raise SynapseError(
1005 HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
1006 )
9621007
9631008 await self.store.set_shadow_banned(UserID.from_string(user_id), False)
9641009
965 return 200, {}
1010 return HTTPStatus.OK, {}
9661011
9671012
9681013 class RateLimitRestServlet(RestServlet):
9941039 await assert_requester_is_admin(self.auth, request)
9951040
9961041 if not self.hs.is_mine_id(user_id):
997 raise SynapseError(400, "Can only look up local users")
1042 raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
9981043
9991044 if not await self.store.get_user_by_id(user_id):
10001045 raise NotFoundError("User not found")
10151060 else:
10161061 ret = {}
10171062
1018 return 200, ret
1063 return HTTPStatus.OK, ret
10191064
10201065 async def on_POST(
10211066 self, request: SynapseRequest, user_id: str
10231068 await assert_requester_is_admin(self.auth, request)
10241069
10251070 if not self.hs.is_mine_id(user_id):
1026 raise SynapseError(400, "Only local users can be ratelimited")
1071 raise SynapseError(
1072 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
1073 )
10271074
10281075 if not await self.store.get_user_by_id(user_id):
10291076 raise NotFoundError("User not found")
10351082
10361083 if not isinstance(messages_per_second, int) or messages_per_second < 0:
10371084 raise SynapseError(
1038 400,
1085 HTTPStatus.BAD_REQUEST,
10391086 "%r parameter must be a positive int" % (messages_per_second,),
10401087 errcode=Codes.INVALID_PARAM,
10411088 )
10421089
10431090 if not isinstance(burst_count, int) or burst_count < 0:
10441091 raise SynapseError(
1045 400,
1092 HTTPStatus.BAD_REQUEST,
10461093 "%r parameter must be a positive int" % (burst_count,),
10471094 errcode=Codes.INVALID_PARAM,
10481095 )
10581105 "burst_count": ratelimit.burst_count,
10591106 }
10601107
1061 return 200, ret
1108 return HTTPStatus.OK, ret
10621109
10631110 async def on_DELETE(
10641111 self, request: SynapseRequest, user_id: str
10661113 await assert_requester_is_admin(self.auth, request)
10671114
10681115 if not self.hs.is_mine_id(user_id):
1069 raise SynapseError(400, "Only local users can be ratelimited")
1116 raise SynapseError(
1117 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
1118 )
10701119
10711120 if not await self.store.get_user_by_id(user_id):
10721121 raise NotFoundError("User not found")
10731122
10741123 await self.store.delete_ratelimit_for_user(user_id)
10751124
1076 return 200, {}
1125 return HTTPStatus.OK, {}
1313
1414 import logging
1515 import re
16 from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
16 from typing import (
17 TYPE_CHECKING,
18 Any,
19 Awaitable,
20 Callable,
21 Dict,
22 List,
23 Optional,
24 Tuple,
25 Union,
26 )
1727
1828 from typing_extensions import TypedDict
1929
2737 from synapse.http.servlet import (
2838 RestServlet,
2939 assert_params_in_dict,
30 parse_boolean,
3140 parse_bytes_from_args,
3241 parse_json_object_from_request,
3342 parse_string,
6271 JWT_TYPE_DEPRECATED = "m.login.jwt"
6372 APPSERVICE_TYPE = "m.login.application_service"
6473 APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service"
65 REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
74 REFRESH_TOKEN_PARAM = "refresh_token"
6675
6776 def __init__(self, hs: "HomeServer"):
6877 super().__init__()
8089 self.saml2_enabled = hs.config.saml2.saml2_enabled
8190 self.cas_enabled = hs.config.cas.cas_enabled
8291 self.oidc_enabled = hs.config.oidc.oidc_enabled
83 self._msc2918_enabled = (
92 self._refresh_tokens_enabled = (
8493 hs.config.registration.refreshable_access_token_lifetime is not None
8594 )
8695
153162 async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
154163 login_submission = parse_json_object_from_request(request)
155164
156 if self._msc2918_enabled:
157 # Check if this login should also issue a refresh token, as per
158 # MSC2918
159 should_issue_refresh_token = parse_boolean(
160 request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
161 )
162 else:
163 should_issue_refresh_token = False
165 # Check to see if the client requested a refresh token.
166 client_requested_refresh_token = login_submission.get(
167 LoginRestServlet.REFRESH_TOKEN_PARAM, False
168 )
169 if not isinstance(client_requested_refresh_token, bool):
170 raise SynapseError(400, "`refresh_token` should be true or false.")
171
172 should_issue_refresh_token = (
173 self._refresh_tokens_enabled and client_requested_refresh_token
174 )
164175
165176 try:
166177 if login_submission["type"] in (
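The login hunk above stabilises refresh-token negotiation: the client now asks for a refresh token via a boolean `refresh_token` field in the login body, replacing the `org.matrix.msc2918.refresh_token` query parameter. A hedged sketch of the new logic (the function name is illustrative):

```python
# Illustrative sketch of the negotiation above: a refresh token is issued only
# when the server has refresh tokens enabled AND the client asked for one.
def should_issue_refresh_token(login_submission: dict, refresh_tokens_enabled: bool) -> bool:
    requested = login_submission.get("refresh_token", False)
    if not isinstance(requested, bool):
        raise ValueError("`refresh_token` should be true or false.")
    return refresh_tokens_enabled and requested

assert should_issue_refresh_token({"refresh_token": True}, True) is True
assert should_issue_refresh_token({"refresh_token": True}, False) is False
assert should_issue_refresh_token({}, True) is False
```

Moving the flag into the JSON body also lets the server reject non-boolean values with a 400, as the hunk does.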
290301 ratelimit: bool = True,
291302 auth_provider_id: Optional[str] = None,
292303 should_issue_refresh_token: bool = False,
304 auth_provider_session_id: Optional[str] = None,
293305 ) -> LoginResponse:
294306 """Called when we've successfully authed the user and now need to
295307 actually login them in (e.g. create devices). This gets called on
305317 create_non_existent_users: Whether to create the user if they don't
306318 exist. Defaults to False.
307319 ratelimit: Whether to ratelimit the login request.
308 auth_provider_id: The SSO IdP the user used, if any (just used for the
309 prometheus metrics).
320 auth_provider_id: The SSO IdP the user used, if any.
310321 should_issue_refresh_token: True if this login should issue
311322 a refresh token alongside the access token.
323 auth_provider_session_id: The session ID obtained during login from the SSO IdP.
312324
313325 Returns:
314326 result: Dictionary of account information after successful login.
341353 initial_display_name,
342354 auth_provider_id=auth_provider_id,
343355 should_issue_refresh_token=should_issue_refresh_token,
356 auth_provider_session_id=auth_provider_session_id,
344357 )
345358
346359 result = LoginResponse(
386399 self.auth_handler._sso_login_callback,
387400 auth_provider_id=res.auth_provider_id,
388401 should_issue_refresh_token=should_issue_refresh_token,
402 auth_provider_session_id=res.auth_provider_session_id,
389403 )
390404
391405 async def _do_jwt_login(
447461
448462
449463 class RefreshTokenServlet(RestServlet):
450 PATTERNS = client_patterns(
451 "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
452 )
464 PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),)
453465
454466 def __init__(self, hs: "HomeServer"):
455467 self._auth_handler = hs.get_auth_handler()
457469 self.refreshable_access_token_lifetime = (
458470 hs.config.registration.refreshable_access_token_lifetime
459471 )
472 self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
460473
461474 async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
462475 refresh_submission = parse_json_object_from_request(request)
466479 if not isinstance(token, str):
467480 raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
468481
469 valid_until_ms = (
470 self._clock.time_msec() + self.refreshable_access_token_lifetime
471 )
472 access_token, refresh_token = await self._auth_handler.refresh_token(
473 token, valid_until_ms
474 )
475 expires_in_ms = valid_until_ms - self._clock.time_msec()
476 return (
477 200,
478 {
479 "access_token": access_token,
480 "refresh_token": refresh_token,
481 "expires_in_ms": expires_in_ms,
482 },
483 )
482 now = self._clock.time_msec()
483 access_valid_until_ms = None
484 if self.refreshable_access_token_lifetime is not None:
485 access_valid_until_ms = now + self.refreshable_access_token_lifetime
486 refresh_valid_until_ms = None
487 if self.refresh_token_lifetime is not None:
488 refresh_valid_until_ms = now + self.refresh_token_lifetime
489
490 (
491 access_token,
492 refresh_token,
493 actual_access_token_expiry,
494 ) = await self._auth_handler.refresh_token(
495 token, access_valid_until_ms, refresh_valid_until_ms
496 )
497
498 response: Dict[str, Union[str, int]] = {
499 "access_token": access_token,
500 "refresh_token": refresh_token,
501 }
502
503 # expires_in_ms is only present if the token expires
504 if actual_access_token_expiry is not None:
505 response["expires_in_ms"] = actual_access_token_expiry - now
506
507 return 200, response
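The rewritten `on_POST` above derives the two token expiries from two independent, optional lifetimes. A minimal standalone sketch of that computation (hypothetical helper name; a lifetime of `None` means the corresponding token never expires):

```python
def compute_token_expiries(now_ms, refreshable_access_token_lifetime_ms,
                           refresh_token_lifetime_ms):
    """Hypothetical standalone version of the expiry computation in
    RefreshTokenServlet.on_POST: each lifetime is optional, and None means
    the corresponding token does not expire."""
    access_valid_until_ms = None
    if refreshable_access_token_lifetime_ms is not None:
        access_valid_until_ms = now_ms + refreshable_access_token_lifetime_ms
    refresh_valid_until_ms = None
    if refresh_token_lifetime_ms is not None:
        refresh_valid_until_ms = now_ms + refresh_token_lifetime_ms
    return access_valid_until_ms, refresh_valid_until_ms
```

As in the servlet, `expires_in_ms` would then only be included in the response when the access token actually expires.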
484508
485509
486510 class SsoRedirectServlet(RestServlet):
488512 re.compile(
489513 "^"
490514 + CLIENT_API_PREFIX
491 + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
515 + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
492516 )
493517 ]
494518
4040 from synapse.http.servlet import (
4141 RestServlet,
4242 assert_params_in_dict,
43 parse_boolean,
4443 parse_json_object_from_request,
4544 parse_string,
4645 )
419418 self.password_policy_handler = hs.get_password_policy_handler()
420419 self.clock = hs.get_clock()
421420 self._registration_enabled = self.hs.config.registration.enable_registration
422 self._msc2918_enabled = (
421 self._refresh_tokens_enabled = (
423422 hs.config.registration.refreshable_access_token_lifetime is not None
424423 )
425424
445444 f"Do not understand membership kind: {kind}",
446445 )
447446
448 if self._msc2918_enabled:
449 # Check if this registration should also issue a refresh token, as
450 # per MSC2918
451 should_issue_refresh_token = parse_boolean(
452 request, name="org.matrix.msc2918.refresh_token", default=False
453 )
454 else:
455 should_issue_refresh_token = False
447 # Check if the client wishes for this registration to issue a refresh
448 # token.
449 client_requested_refresh_tokens = body.get("refresh_token", False)
450 if not isinstance(client_requested_refresh_tokens, bool):
451 raise SynapseError(400, "`refresh_token` should be true or false.")
452
453 should_issue_refresh_token = (
454 self._refresh_tokens_enabled and client_requested_refresh_tokens
455 )
456456
457457 # Pull out the provided username and do basic sanity checks early since
458458 # the auth layer will store these in sessions.
223223 )
224224
225225 now = self.clock.time_msec()
226 # We set bundle_relations to False when retrieving the original
227 # event because we want the content before relations were applied to
228 # it.
226 # Do not bundle aggregations when retrieving the original event because
227 # we want the content before relations are applied to it.
229228 original_event = await self._event_serializer.serialize_event(
230 event, now, bundle_relations=False
231 )
232 # Similarly, we don't allow relations to be applied to relations, so we
233 # return the original relations without any aggregations on top of them
234 # here.
235 serialized_events = await self._event_serializer.serialize_events(
236 events, now, bundle_relations=False
237 )
229 event, now, bundle_aggregations=False
230 )
231 # The relations returned for the requested event do include their
232 # bundled aggregations.
233 serialized_events = await self._event_serializer.serialize_events(events, now)
238234
239235 return_value = pagination_chunk.to_dict()
240236 return_value["chunk"] = serialized_events
715715 results["events_after"], time_now
716716 )
717717 results["state"] = await self._event_serializer.serialize_events(
718 results["state"],
719 time_now,
720 # No need to bundle aggregations for state events
721 bundle_relations=False,
718 results["state"], time_now
722719 )
723720
724721 return 200, results
10691066 )
10701067
10711068
1069 class TimestampLookupRestServlet(RestServlet):
1070 """
1071 API endpoint to fetch the `event_id` of the closest event to the given
1072 timestamp (`ts` query parameter) in the given direction (`dir` query
1073 parameter).
1074
1075 Useful for cases like jump-to-date, so you can start paginating messages
1076 from a given date in the archive.
1077
1078 `ts` is a timestamp in milliseconds where we will find the closest event in
1079 the given direction.
1080
1081 `dir` can be `f` or `b` to indicate forwards and backwards in time from the
1082 given timestamp.
1083
1084 GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
1085 {
1086 "event_id": ...
1087 }
1088 """
1089
1090 PATTERNS = (
1091 re.compile(
1092 "^/_matrix/client/unstable/org.matrix.msc3030"
1093 "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"
1094 ),
1095 )
1096
1097 def __init__(self, hs: "HomeServer"):
1098 super().__init__()
1099 self._auth = hs.get_auth()
1100 self._store = hs.get_datastore()
1101 self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler()
1102
1103 async def on_GET(
1104 self, request: SynapseRequest, room_id: str
1105 ) -> Tuple[int, JsonDict]:
1106 requester = await self._auth.get_user_by_req(request)
1107 await self._auth.check_user_in_room(room_id, requester.user.to_string())
1108
1109 timestamp = parse_integer(request, "ts", required=True)
1110 direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
1111
1112 (
1113 event_id,
1114 origin_server_ts,
1115 ) = await self.timestamp_lookup_handler.get_event_for_timestamp(
1116 requester, room_id, timestamp, direction
1117 )
1118
1119 return 200, {
1120 "event_id": event_id,
1121 "origin_server_ts": origin_server_ts,
1122 }
1123
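The lookup semantics described in the docstring above (closest event at or after the timestamp for `dir=f`, at or before it for `dir=b`) can be sketched in memory; the helper name and data shape here are hypothetical, standing in for the database query the handler actually runs:

```python
import bisect

def closest_event_for_timestamp(events, ts, direction):
    """Hypothetical in-memory sketch of the timestamp_to_event lookup.

    `events` is a list of (origin_server_ts, event_id) pairs sorted by
    timestamp. Direction "f" returns the first event at or after `ts`;
    "b" returns the last event at or before it; None if nothing matches.
    """
    timestamps = [t for t, _ in events]
    if direction == "f":
        i = bisect.bisect_left(timestamps, ts)
        return events[i] if i < len(events) else None
    else:
        i = bisect.bisect_right(timestamps, ts)
        return events[i - 1] if i > 0 else None
```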
1124
10721125 class RoomSpaceSummaryRestServlet(RestServlet):
10731126 PATTERNS = (
10741127 re.compile(
11391192 class RoomHierarchyRestServlet(RestServlet):
11401193 PATTERNS = (
11411194 re.compile(
1142 "^/_matrix/client/unstable/org.matrix.msc2946"
1195 "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
11431196 "/rooms/(?P<room_id>[^/]*)/hierarchy$"
11441197 ),
11451198 )
11671220 )
11681221
11691222 return 200, await self._room_summary_handler.get_room_hierarchy(
1170 requester.user.to_string(),
1223 requester,
11711224 room_id,
11721225 suggested_only=parse_boolean(request, "suggested_only", default=False),
11731226 max_depth=max_depth,
12381291 RoomAliasListServlet(hs).register(http_server)
12391292 SearchRestServlet(hs).register(http_server)
12401293 RoomCreateRestServlet(hs).register(http_server)
1294 if hs.config.experimental.msc3030_enabled:
1295 TimestampLookupRestServlet(hs).register(http_server)
12411296
12421297 # Some servlets only get registered for the main process.
12431298 if not is_worker:
519519 return self._event_serializer.serialize_events(
520520 events,
521521 time_now=time_now,
522 # We don't bundle "live" events, as otherwise clients
523 # will end up double counting annotations.
524 bundle_relations=False,
522 # Don't bother to bundle aggregations if the timeline is unlimited,
523 # as clients will have all the necessary information.
524 bundle_aggregations=room.timeline.limited,
525525 token_id=token_id,
526526 event_format=event_formatter,
527527 only_event_fields=only_fields,
4242 )
4343
4444
45 def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod:
45 def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]:
4646 """Wraps a path-returning method to check that the returned path(s) do not escape
4747 the media store directory.
4848
49 The path-returning method may return either a single path, or a list of paths.
50
4951 The check is not expected to ever fail, unless `func` is missing a call to
5052 `_validate_path_component`, or `_validate_path_component` is buggy.
5153
5254 Args:
53 func: The `MediaFilePaths` method to wrap. The method may return either a single
54 path, or a list of paths. Returned paths may be either absolute or relative.
55 relative: A boolean indicating whether the wrapped method returns paths relative
56 to the media store directory.
5557
5658 Returns:
57 The method, wrapped with a check to ensure that the returned path(s) lie within
58 the media store directory. Raises a `ValueError` if the check fails.
59 A method which will wrap a path-returning method, adding a check to ensure that
60 the returned path(s) lie within the media store directory. The check will raise
61 a `ValueError` if it fails.
5962 """
6063
61 @functools.wraps(func)
62 def _wrapped(
63 self: "MediaFilePaths", *args: Any, **kwargs: Any
64 ) -> Union[str, List[str]]:
65 path_or_paths = func(self, *args, **kwargs)
66
67 if isinstance(path_or_paths, list):
68 paths_to_check = path_or_paths
69 else:
70 paths_to_check = [path_or_paths]
71
72 for path in paths_to_check:
73 # path may be an absolute or relative path, depending on the method being
74 # wrapped. When "appending" an absolute path, `os.path.join` discards the
75 # previous path, which is desired here.
76 normalized_path = os.path.normpath(os.path.join(self.real_base_path, path))
77 if (
78 os.path.commonpath([normalized_path, self.real_base_path])
79 != self.real_base_path
80 ):
81 raise ValueError(f"Invalid media store path: {path!r}")
82
83 return path_or_paths
84
85 return cast(GetPathMethod, _wrapped)
64 def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod:
65 @functools.wraps(func)
66 def _wrapped(
67 self: "MediaFilePaths", *args: Any, **kwargs: Any
68 ) -> Union[str, List[str]]:
69 path_or_paths = func(self, *args, **kwargs)
70
71 if isinstance(path_or_paths, list):
72 paths_to_check = path_or_paths
73 else:
74 paths_to_check = [path_or_paths]
75
76 for path in paths_to_check:
77 # Construct the path that will ultimately be used.
78 # We cannot guess whether `path` is relative to the media store
79 # directory, since the media store directory may itself be a relative
80 # path.
81 if relative:
82 path = os.path.join(self.base_path, path)
83 normalized_path = os.path.normpath(path)
84
85 # Now that `normpath` has eliminated `../`s and `./`s from the path,
86 # `os.path.commonpath` can be used to check whether it lies within the
87 # media store directory.
88 if (
89 os.path.commonpath([normalized_path, self.normalized_base_path])
90 != self.normalized_base_path
91 ):
92 # The path resolves to outside the media store directory,
93 # or `self.base_path` is `.`, which is an unlikely configuration.
94 raise ValueError(f"Invalid media store path: {path!r}")
95
96 # Note that `os.path.normpath`/`abspath` has a subtle caveat:
97 # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a
98 # different path if `a/b/c` is a symlink. That is, the check above is
99 # not perfect and may allow a certain restricted subset of untrustworthy
100 # paths through. Since the check above is secondary to the main
101 # `_validate_path_component` checks, it's less important for it to be
102 # perfect.
103 #
104 # As an alternative, `os.path.realpath` will resolve symlinks, but
105 # proves problematic if there are symlinks inside the media store.
106 # eg. if `url_store/` is symlinked to elsewhere, its canonical path
107 # won't match that of the main media store directory.
108
109 return path_or_paths
110
111 return cast(GetPathMethod, _wrapped)
112
113 return _wrap_with_jail_check_inner
86114
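The jail check above boils down to: join relative paths onto the (possibly relative) media store directory, normalize away `../` and `./` components, then verify the result still lies under the normalized base. A standalone sketch under those assumptions (hypothetical function name; POSIX paths, matching the `os.name == "posix"` assertion in the real class):

```python
import os.path

def is_within_media_store(base_path, path, relative):
    """Sketch of the jail check from _wrap_with_jail_check: returns True
    if `path` (relative to `base_path` when `relative` is set) resolves
    to somewhere inside the media store directory."""
    if relative:
        # When "appending" an absolute path, os.path.join would discard
        # base_path, so the commonpath check below still catches escapes.
        path = os.path.join(base_path, path)
    normalized_path = os.path.normpath(path)
    normalized_base = os.path.normpath(base_path)
    return os.path.commonpath([normalized_path, normalized_base]) == normalized_base
```

As the comments in the diff note, `normpath` does not resolve symlinks, so this remains a secondary defence behind `_validate_path_component`.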
87115
88116 ALLOWED_CHARACTERS = set(
126154
127155 def __init__(self, primary_base_path: str):
128156 self.base_path = primary_base_path
129
130 # The media store directory, with all symlinks resolved.
131 self.real_base_path = os.path.realpath(primary_base_path)
157 self.normalized_base_path = os.path.normpath(self.base_path)
132158
133159 # Refuse to initialize if paths cannot be validated correctly for the current
134160 # platform.
139165 # for certain homeservers there, since ":"s aren't allowed in paths.
140166 assert os.name == "posix"
141167
142 @_wrap_with_jail_check
168 @_wrap_with_jail_check(relative=True)
143169 def local_media_filepath_rel(self, media_id: str) -> str:
144170 return os.path.join(
145171 "local_content",
150176
151177 local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
152178
153 @_wrap_with_jail_check
179 @_wrap_with_jail_check(relative=True)
154180 def local_media_thumbnail_rel(
155181 self, media_id: str, width: int, height: int, content_type: str, method: str
156182 ) -> str:
166192
167193 local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
168194
169 @_wrap_with_jail_check
195 @_wrap_with_jail_check(relative=False)
170196 def local_media_thumbnail_dir(self, media_id: str) -> str:
171197 """
172198 Retrieve the local store path of thumbnails of a given media_id
184210 _validate_path_component(media_id[4:]),
185211 )
186212
187 @_wrap_with_jail_check
213 @_wrap_with_jail_check(relative=True)
188214 def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
189215 return os.path.join(
190216 "remote_content",
196222
197223 remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
198224
199 @_wrap_with_jail_check
225 @_wrap_with_jail_check(relative=True)
200226 def remote_media_thumbnail_rel(
201227 self,
202228 server_name: str,
222248 # Legacy path that was used to store thumbnails previously.
223249 # Should be removed after some time, when most of the thumbnails are stored
224250 # using the new path.
225 @_wrap_with_jail_check
251 @_wrap_with_jail_check(relative=True)
226252 def remote_media_thumbnail_rel_legacy(
227253 self, server_name: str, file_id: str, width: int, height: int, content_type: str
228254 ) -> str:
237263 _validate_path_component(file_name),
238264 )
239265
266 @_wrap_with_jail_check(relative=False)
240267 def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
241268 return os.path.join(
242269 self.base_path,
247274 _validate_path_component(file_id[4:]),
248275 )
249276
250 @_wrap_with_jail_check
277 @_wrap_with_jail_check(relative=True)
251278 def url_cache_filepath_rel(self, media_id: str) -> str:
252279 if NEW_FORMAT_ID_RE.match(media_id):
253280 # Media id is of the form <DATE><RANDOM_STRING>
267294
268295 url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
269296
270 @_wrap_with_jail_check
297 @_wrap_with_jail_check(relative=False)
271298 def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
272299 "The dirs to try and remove if we delete the media_id file"
273300 if NEW_FORMAT_ID_RE.match(media_id):
289316 ),
290317 ]
291318
292 @_wrap_with_jail_check
319 @_wrap_with_jail_check(relative=True)
293320 def url_cache_thumbnail_rel(
294321 self, media_id: str, width: int, height: int, content_type: str, method: str
295322 ) -> str:
317344
318345 url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
319346
320 @_wrap_with_jail_check
347 @_wrap_with_jail_check(relative=True)
321348 def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
322349 # Media id is of the form <DATE><RANDOM_STRING>
323350 # E.g.: 2017-09-28-fsdRDt24DS234dsf
340367 url_cache_thumbnail_directory_rel
341368 )
342369
343 @_wrap_with_jail_check
370 @_wrap_with_jail_check(relative=False)
344371 def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
345372 "The dirs to try and remove if we delete the media_id thumbnails"
346373 # Media id is of the form <DATE><RANDOM_STRING>
9696 RoomContextHandler,
9797 RoomCreationHandler,
9898 RoomShutdownHandler,
99 TimestampLookupHandler,
99100 )
100101 from synapse.handlers.room_batch import RoomBatchHandler
101102 from synapse.handlers.room_list import RoomListHandler
728729 return RoomContextHandler(self)
729730
730731 @cache_in_self
732 def get_timestamp_lookup_handler(self) -> TimestampLookupHandler:
733 return TimestampLookupHandler(self)
734
735 @cache_in_self
731736 def get_registration_handler(self) -> RegistrationHandler:
732737 return RegistrationHandler(self)
733738
763763 store: "DataStore"
764764
765765 def get_events(
766 self, event_ids: Iterable[str], allow_rejected: bool = False
766 self, event_ids: Collection[str], allow_rejected: bool = False
767767 ) -> Awaitable[Dict[str, EventBase]]:
768768 """Get events from the database
769769
1616 from typing import (
1717 Awaitable,
1818 Callable,
19 Collection,
1920 Dict,
2021 Iterable,
2122 List,
4344 room_version: RoomVersion,
4445 state_sets: Sequence[StateMap[str]],
4546 event_map: Optional[Dict[str, EventBase]],
46 state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
47 state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
4748 ) -> StateMap[str]:
4849 """
4950 Args:
2020 from synapse.storage.database import make_in_list_sql_clause # noqa: F401
2121 from synapse.storage.database import DatabasePool
2222 from synapse.storage.types import Connection
23 from synapse.types import StreamToken, get_domain_from_id
23 from synapse.types import get_domain_from_id
2424 from synapse.util import json_decoder
2525
2626 if TYPE_CHECKING:
4747 self,
4848 stream_name: str,
4949 instance_name: str,
50 token: StreamToken,
50 token: int,
5151 rows: Iterable[Any],
5252 ) -> None:
5353 pass
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313 import logging
14 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
14 from typing import (
15 TYPE_CHECKING,
16 AsyncContextManager,
17 Awaitable,
18 Callable,
19 Dict,
20 Iterable,
21 Optional,
22 )
23
24 import attr
1525
1626 from synapse.metrics.background_process_metrics import run_as_background_process
1727 from synapse.storage.types import Connection
1828 from synapse.types import JsonDict
19 from synapse.util import json_encoder
29 from synapse.util import Clock, json_encoder
2030
2131 from . import engines
2232
2535 from synapse.storage.database import DatabasePool, LoggingTransaction
2636
2737 logger = logging.getLogger(__name__)
38
39
40 ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
41 DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
42 MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
43
44
45 @attr.s(slots=True, frozen=True, auto_attribs=True)
46 class _BackgroundUpdateHandler:
47 """A handler for a given background update.
48
49 Attributes:
50 callback: The function to call to make progress on the background
51 update.
52 oneshot: Whether the update is likely to happen all in one go, ignoring
53 the supplied target duration, e.g. index creation. This is used by
54 the update controller to help correctly schedule the update.
55 """
56
57 callback: Callable[[JsonDict, int], Awaitable[int]]
58 oneshot: bool = False
59
60
61 class _BackgroundUpdateContextManager:
62 BACKGROUND_UPDATE_INTERVAL_MS = 1000
63 BACKGROUND_UPDATE_DURATION_MS = 100
64
65 def __init__(self, sleep: bool, clock: Clock):
66 self._sleep = sleep
67 self._clock = clock
68
69 async def __aenter__(self) -> int:
70 if self._sleep:
71 await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
72
73 return self.BACKGROUND_UPDATE_DURATION_MS
74
75 async def __aexit__(self, *exc) -> None:
76 pass
2877
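The context manager above paces update iterations: optionally sleep on entry, then hand the caller a target duration in milliseconds. A trimmed, hypothetical analogue showing how a caller consumes it:

```python
import asyncio

class PacingContext:
    """Hypothetical trimmed analogue of _BackgroundUpdateContextManager:
    optionally sleep before an update iteration, then yield the target
    duration (in ms) that the iteration should aim for."""

    def __init__(self, sleep_s, duration_ms):
        self._sleep_s = sleep_s
        self._duration_ms = duration_ms

    async def __aenter__(self):
        if self._sleep_s:
            await asyncio.sleep(self._sleep_s)
        return self._duration_ms

    async def __aexit__(self, *exc):
        pass

async def run_once():
    # Mirrors the `async with self._get_context_manager_for_update(...)`
    # call site further down in this diff.
    async with PacingContext(sleep_s=0, duration_ms=100) as target_ms:
        return target_ms
```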
2978
3079 class BackgroundUpdatePerformance:
83132
84133 MINIMUM_BACKGROUND_BATCH_SIZE = 1
85134 DEFAULT_BACKGROUND_BATCH_SIZE = 100
86 BACKGROUND_UPDATE_INTERVAL_MS = 1000
87 BACKGROUND_UPDATE_DURATION_MS = 100
88135
89136 def __init__(self, hs: "HomeServer", database: "DatabasePool"):
90137 self._clock = hs.get_clock()
91138 self.db_pool = database
92139
140 self._database_name = database.name()
141
93142 # if a background update is currently running, its name.
94143 self._current_background_update: Optional[str] = None
95144
145 self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
146 self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
147 self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
148
96149 self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
97 self._background_update_handlers: Dict[
98 str, Callable[[JsonDict, int], Awaitable[int]]
99 ] = {}
150 self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
100151 self._all_done = False
101152
102153 # Whether we're currently running updates
105156 # Whether background updates are enabled. This allows us to
106157 # enable/disable background updates via the admin API.
107158 self.enabled = True
159
160 def register_update_controller_callbacks(
161 self,
162 on_update: ON_UPDATE_CALLBACK,
163 default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
164 min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
165 ) -> None:
166 """Register callbacks from a module for each hook."""
167 if self._on_update_callback is not None:
168 logger.warning(
169 "More than one module tried to register callbacks for controlling"
170 " background updates. Only the callbacks registered by the first module"
171 " (in order of appearance in Synapse's configuration file) that tried to"
172 " do so will be called."
173 )
174
175 return
176
177 self._on_update_callback = on_update
178
179 if default_batch_size is not None:
180 self._default_batch_size_callback = default_batch_size
181
182 if min_batch_size is not None:
183 self._min_batch_size_callback = min_batch_size
184
185 def _get_context_manager_for_update(
186 self,
187 sleep: bool,
188 update_name: str,
189 database_name: str,
190 oneshot: bool,
191 ) -> AsyncContextManager[int]:
192 """Get a context manager to run a background update with.
193
194 If a module has registered an `update_handler` callback, use the context manager
195 it returns.
196
197 Otherwise, returns a context manager that will return a default value, optionally
198 sleeping if needed.
199
200 Args:
201 sleep: Whether we can sleep between updates.
202 update_name: The name of the update.
203 database_name: The name of the database the update is being run on.
204 oneshot: Whether the update will complete all in one go, e.g. index creation.
205 In such cases the returned target duration is ignored.
206
207 Returns:
208 A context manager that, on entry, returns the target duration in milliseconds that the background update should run for.
209
210 Note: this is a *target*, and an iteration may take substantially longer or
211 shorter.
212 """
213 if self._on_update_callback is not None:
214 return self._on_update_callback(update_name, database_name, oneshot)
215
216 return _BackgroundUpdateContextManager(sleep, self._clock)
217
218 async def _default_batch_size(self, update_name: str, database_name: str) -> int:
219 """The batch size to use for the first iteration of a new background
220 update.
221 """
222 if self._default_batch_size_callback is not None:
223 return await self._default_batch_size_callback(update_name, database_name)
224
225 return self.DEFAULT_BACKGROUND_BATCH_SIZE
226
227 async def _min_batch_size(self, update_name: str, database_name: str) -> int:
228 """A lower bound on the batch size of a new background update.
229
230 Used to ensure that progress is always made. Must be greater than 0.
231 """
232 if self._min_batch_size_callback is not None:
233 return await self._min_batch_size_callback(update_name, database_name)
234
235 return self.MINIMUM_BACKGROUND_BATCH_SIZE
108236
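The batch sizing used later in `_do_background_update` combines these hooks: scale the batch to hit the target duration using measured throughput, clamp to the minimum so progress is always made, and fall back to the default when there is no performance data yet. A sketch with a hypothetical helper name:

```python
def next_batch_size(desired_duration_ms, items_per_ms, minimum, default):
    """Sketch of the batch sizing in _do_background_update: items_per_ms
    is the measured throughput (None before the first iteration)."""
    if items_per_ms is not None:
        # Aim for the target duration, but always make some progress.
        return max(int(desired_duration_ms * items_per_ms), minimum)
    return default
```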
109237 def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
110238 """Returns the current background update, if any."""
134262 try:
135263 logger.info("Starting background schema updates")
136264 while self.enabled:
137 if sleep:
138 await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
139
140265 try:
141 result = await self.do_next_background_update(
142 self.BACKGROUND_UPDATE_DURATION_MS
143 )
266 result = await self.do_next_background_update(sleep)
144267 except Exception:
145268 logger.exception("Error doing update")
146269 else:
202325
203326 return not update_exists
204327
205 async def do_next_background_update(self, desired_duration_ms: float) -> bool:
328 async def do_next_background_update(self, sleep: bool = True) -> bool:
206329 """Does some amount of work on the next queued background update
207330
208331 Returns once some amount of work is done.
209332
210333 Args:
211 desired_duration_ms: How long we want to spend updating.
334 sleep: Whether to limit how quickly we run background updates or
335 not.
336
212337 Returns:
213338 True if we have finished running all the background updates, otherwise False
214339 """
251376
252377 self._current_background_update = upd["update_name"]
253378
254 await self._do_background_update(desired_duration_ms)
379 # We have a background update to run, otherwise we would have returned
380 # early.
381 assert self._current_background_update is not None
382 update_info = self._background_update_handlers[self._current_background_update]
383
384 async with self._get_context_manager_for_update(
385 sleep=sleep,
386 update_name=self._current_background_update,
387 database_name=self._database_name,
388 oneshot=update_info.oneshot,
389 ) as desired_duration_ms:
390 await self._do_background_update(desired_duration_ms)
391
255392 return False
256393
257394 async def _do_background_update(self, desired_duration_ms: float) -> int:
259396 update_name = self._current_background_update
260397 logger.info("Starting update batch on background update '%s'", update_name)
261398
262 update_handler = self._background_update_handlers[update_name]
399 update_handler = self._background_update_handlers[update_name].callback
263400
264401 performance = self._background_update_performance.get(update_name)
265402
272409 if items_per_ms is not None:
273410 batch_size = int(desired_duration_ms * items_per_ms)
274411 # Clamp the batch size so that we always make progress
275 batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
412 batch_size = max(
413 batch_size,
414 await self._min_batch_size(update_name, self._database_name),
415 )
276416 else:
277 batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
417 batch_size = await self._default_batch_size(
418 update_name, self._database_name
419 )
278420
279421 progress_json = await self.db_pool.simple_select_one_onecol(
280422 "background_updates",
292434 time_stop = self._clock.time_msec()
293435
294436 duration_ms = time_stop - time_start
437
438 performance.update(items_updated, duration_ms)
295439
296440 logger.info(
297441 "Running background update %r. Processed %r items in %rms."
305449 batch_size,
306450 )
307451
308 performance.update(items_updated, duration_ms)
309
310452 return len(self._background_update_performance)
311453
312454 def register_background_update_handler(
330472 update_name: The name of the update that this code handles.
331473 update_handler: The function that does the update.
332474 """
333 self._background_update_handlers[update_name] = update_handler
475 self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
476 update_handler
477 )
334478
335479 def register_noop_background_update(self, update_name: str) -> None:
336480 """Register a noop handler for a background update.
452596 await self._end_background_update(update_name)
453597 return 1
454598
455 self.register_background_update_handler(update_name, updater)
599 self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
600 updater, oneshot=True
601 )
456602
457603 async def _end_background_update(self, update_name: str) -> None:
458604 """Removes a completed background update task from the queue.
142142 A list of ApplicationServices, which may be empty.
143143 """
144144 results = await self.db_pool.simple_select_list(
145 "application_services_state", {"state": state}, ["as_id"]
145 "application_services_state", {"state": state.value}, ["as_id"]
146146 )
147147 # NB: This assumes this class is linked with ApplicationServiceStore
148148 as_list = self.get_app_services()
172172 desc="get_appservice_state",
173173 )
174174 if result:
175 return result.get("state")
175 return ApplicationServiceState(result.get("state"))
176176 return None
177177
178178 async def set_appservice_state(
185185 state: The connectivity state to apply.
186186 """
187187 await self.db_pool.simple_upsert(
188 "application_services_state", {"as_id": service.id}, {"state": state}
188 "application_services_state", {"as_id": service.id}, {"state": state.value}
189189 )
190190
191191 async def create_appservice_txn(
138138
139139 return {d["device_id"]: d for d in devices}
140140
141 async def get_devices_by_auth_provider_session_id(
142 self, auth_provider_id: str, auth_provider_session_id: str
143 ) -> List[Dict[str, Any]]:
144 """Retrieve the list of devices associated with a SSO IdP session ID.
145
146 Args:
147 auth_provider_id: The SSO IdP ID as defined in the server config
148 auth_provider_session_id: The session ID within the IdP
149 Returns:
150 A list of dicts containing the device_id and the user_id of each device
151 """
152 return await self.db_pool.simple_select_list(
153 table="device_auth_providers",
154 keyvalues={
155 "auth_provider_id": auth_provider_id,
156 "auth_provider_session_id": auth_provider_session_id,
157 },
158 retcols=("user_id", "device_id"),
159 desc="get_devices_by_auth_provider_session_id",
160 )
161
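`get_devices_by_auth_provider_session_id` is what lets a handler for an IdP-initiated logout map an SSO session back to the devices it created. An in-memory sketch of that filter (hypothetical function name; the row shape mirrors the `device_auth_providers` columns):

```python
def devices_for_sso_session(device_auth_providers, auth_provider_id, sid):
    """Hypothetical in-memory version of the lookup above: filter the
    device/auth-provider association rows by IdP ID and IdP session ID,
    returning the (user_id, device_id) pairs to expire."""
    return [
        {"user_id": row["user_id"], "device_id": row["device_id"]}
        for row in device_auth_providers
        if row["auth_provider_id"] == auth_provider_id
        and row["auth_provider_session_id"] == sid
    ]
```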
141162 @trace
142163 async def get_device_updates_by_remote(
143164 self, destination: str, from_stream_id: int, limit: int
10691090 )
10701091
10711092 async def store_device(
1072 self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
1093 self,
1094 user_id: str,
1095 device_id: str,
1096 initial_device_display_name: Optional[str],
1097 auth_provider_id: Optional[str] = None,
1098 auth_provider_session_id: Optional[str] = None,
10731099 ) -> bool:
10741100 """Ensure the given device is known; add it to the store if not
10751101
10781104 device_id: id of device
10791105 initial_device_display_name: initial displayname of the device.
10801106 Ignored if device exists.
1107 auth_provider_id: The SSO IdP the user used, if any.
1107 auth_provider_session_id: The session ID (sid) obtained from an OIDC login.
10811109
10821110 Returns:
10831111 Whether the device was inserted or an existing device existed with that ID.
11131141 )
11141142 if hidden:
11151143 raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
1144
1145 if auth_provider_id and auth_provider_session_id:
1146 await self.db_pool.simple_insert(
1147 "device_auth_providers",
1148 values={
1149 "user_id": user_id,
1150 "device_id": device_id,
1151 "auth_provider_id": auth_provider_id,
1152 "auth_provider_session_id": auth_provider_session_id,
1153 },
1154 desc="store_device_auth_provider",
1155 )
11161156
11171157 self.device_id_exists_cache.set(key, True)
11181158 return inserted
11621202 self.db_pool.simple_delete_many_txn(
11631203 txn,
11641204 table="device_inbox",
1205 column="device_id",
1206 values=device_ids,
1207 keyvalues={"user_id": user_id},
1208 )
1209
1210 self.db_pool.simple_delete_many_txn(
1211 txn,
1212 table="device_auth_providers",
11651213 column="device_id",
11661214 values=device_ids,
11671215 keyvalues={"user_id": user_id},
15511551 DELETE FROM event_auth
15521552 WHERE event_id IN (
15531553 SELECT event_id FROM events
1554 LEFT JOIN state_events USING (room_id, event_id)
1554 LEFT JOIN state_events AS se USING (room_id, event_id)
15551555 WHERE ? <= stream_ordering AND stream_ordering < ?
1556 AND state_key IS null
1556 AND se.state_key IS null
15571557 )
15581558 """
15591559
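The `AS se` alias exists because the `events` table now also carries a `state_key` column, so the unqualified `state_key` in the join became ambiguous. A cut-down sqlite3 demonstration (the two-column schemas are illustrative, not Synapse's real schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE events (event_id TEXT, state_key TEXT);
    CREATE TABLE state_events (event_id TEXT, state_key TEXT);
    INSERT INTO events VALUES ('$state', NULL), ('$message', NULL);
    INSERT INTO state_events VALUES ('$state', '');
""")

# Unqualified `state_key` is rejected once both tables carry the column:
try:
    conn.execute(
        "SELECT event_id FROM events "
        "LEFT JOIN state_events USING (event_id) "
        "WHERE state_key IS NULL"
    )
    ambiguous = False
except sqlite3.OperationalError:
    ambiguous = True

# Aliasing the joined table and qualifying the column restores the original
# meaning (select non-state events):
rows = conn.execute(
    "SELECT event_id FROM events "
    "LEFT JOIN state_events AS se USING (event_id) "
    "WHERE se.state_key IS NULL"
).fetchall()
```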
1515 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1616
1717 import attr
18 from typing_extensions import TypedDict
1819
1920 from synapse.metrics.background_process_metrics import wrap_as_background_process
2021 from synapse.storage._base import SQLBaseStore, db_to_json
3435 {"set_tweak": "sound", "value": "default"},
3536 {"set_tweak": "highlight"},
3637 ]
38
39
40 class BasePushAction(TypedDict):
41 event_id: str
42 actions: List[Union[dict, str]]
43
44
45 class HttpPushAction(BasePushAction):
46 room_id: str
47 stream_ordering: int
48
49
50 class EmailPushAction(HttpPushAction):
51 received_ts: Optional[int]
3752
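The new `BasePushAction`/`HttpPushAction`/`EmailPushAction` hierarchy is structural typing only. Synapse imports `TypedDict` from `typing_extensions` for older interpreters; `typing.TypedDict` (Python 3.8+) is equivalent, as this runtime sketch shows:

```python
from typing import List, Optional, TypedDict, Union

# At runtime these are plain dicts -- the classes exist only so the type
# checker knows which keys each pusher's rows carry.

class BasePushAction(TypedDict):
    event_id: str
    actions: List[Union[dict, str]]

class HttpPushAction(BasePushAction):
    room_id: str
    stream_ordering: int

class EmailPushAction(HttpPushAction):
    received_ts: Optional[int]

action: EmailPushAction = {
    "event_id": "$ev1",
    "actions": ["notify", {"set_tweak": "highlight"}],
    "room_id": "!room:example.org",
    "stream_ordering": 42,
    "received_ts": None,
}
```

Inheriting from `HttpPushAction` means an `EmailPushAction` requires all five keys; `mypy` flags a row missing `received_ts`, which a bare `List[dict]` annotation never could.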
3853
3954 def _serialize_action(actions, is_highlight):
220235 min_stream_ordering: int,
221236 max_stream_ordering: int,
222237 limit: int = 20,
223 ) -> List[dict]:
238 ) -> List[HttpPushAction]:
224239 """Get a list of the most recent unread push actions for a given user,
225240 within the given stream ordering range. Called by the httppusher.
226241
325340 min_stream_ordering: int,
326341 max_stream_ordering: int,
327342 limit: int = 20,
328 ) -> List[dict]:
343 ) -> List[EmailPushAction]:
329344 """Get a list of the most recent unread push actions for a given user,
330345 within the given stream ordering range. Called by the emailpusher
331346
1414 # limitations under the License.
1515 import itertools
1616 import logging
17 from collections import OrderedDict, namedtuple
17 from collections import OrderedDict
1818 from typing import (
1919 TYPE_CHECKING,
2020 Any,
4040 from synapse.logging.utils import log_function
4141 from synapse.storage._base import db_to_json, make_in_list_sql_clause
4242 from synapse.storage.database import DatabasePool, LoggingTransaction
43 from synapse.storage.databases.main.events_worker import EventCacheEntry
4344 from synapse.storage.databases.main.search import SearchEntry
4445 from synapse.storage.types import Connection
45 from synapse.storage.util.id_generators import MultiWriterIdGenerator
46 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
4647 from synapse.storage.util.sequence import SequenceGenerator
4748 from synapse.types import StateMap, get_domain_from_id
4849 from synapse.util import json_encoder
6162 "",
6263 ["type", "origin_type", "origin_entity"],
6364 )
64
65
66 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
6765
6866
6967 @attr.s(slots=True)
107105 self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
108106 self.is_mine_id = hs.is_mine_id
109107
110 # Ideally we'd move these ID gens here, unfortunately some other ID
111 # generators are chained off them so doing so is a bit of a PITA.
112 self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
113 self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
114
115108 # This should only exist on instances that are configured to write
116109 assert (
117110 hs.get_instance_name() in hs.config.worker.writers.events
118111 ), "Can only instantiate EventsStore on master"
119112
113 # Since we have been configured to write, we ought to have id generators,
114 # rather than id trackers.
115 assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
116 assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
117
118 # Ideally we'd move these ID gens here, unfortunately some other ID
119 # generators are chained off them so doing so is a bit of a PITA.
120 self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
121 self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
122
120123 async def _persist_events_and_state_updates(
121124 self,
122125 events_and_contexts: List[Tuple[EventBase, EventContext]],
126 *,
123127 current_state_for_room: Dict[str, StateMap[str]],
124128 state_delta_for_room: Dict[str, DeltaState],
125129 new_forward_extremeties: Dict[str, List[str]],
126 backfilled: bool = False,
130 use_negative_stream_ordering: bool = False,
131 inhibit_local_membership_updates: bool = False,
127132 ) -> None:
128133 """Persist a set of events alongside updates to the current state and
129134 forward extremities tables.
136141 room state
137142 new_forward_extremities: Map from room_id to list of event IDs
138143 that are the new forward extremities of the room.
139 backfilled
144 use_negative_stream_ordering: Whether to start stream_ordering on
145 the negative side and decrement. This should be set as True
146 for backfilled events because backfilled events get a negative
147 stream ordering so they don't come down an incremental `/sync`.
148 inhibit_local_membership_updates: Stop the local_current_membership
149 from being updated by these events. This should be set to True
150 for backfilled events because backfilled events in the past do
151 not affect the current local state.
140152
141153 Returns:
142154 Resolves when the events have been persisted
158170 #
159171 # Note: Multiple instances of this function cannot be in flight at
160172 # the same time for the same room.
161 if backfilled:
173 if use_negative_stream_ordering:
162174 stream_ordering_manager = self._backfill_id_gen.get_next_mult(
163175 len(events_and_contexts)
164176 )
175187 "persist_events",
176188 self._persist_events_txn,
177189 events_and_contexts=events_and_contexts,
178 backfilled=backfilled,
190 inhibit_local_membership_updates=inhibit_local_membership_updates,
179191 state_delta_for_room=state_delta_for_room,
180192 new_forward_extremeties=new_forward_extremeties,
181193 )
182194 persist_event_counter.inc(len(events_and_contexts))
183195
184 if not backfilled:
196 if stream < 0:
185197 # backfilled events have negative stream orderings, so we don't
186198 # want to set the event_persisted_position to that.
187199 synapse.metrics.event_persisted_position.set(
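A minimal sketch (not Synapse's `MultiWriterIdGenerator`) of the `use_negative_stream_ordering` split: live persistence hands out increasing positive stream orderings while backfill hands out decreasing negative ones, so backfilled events sort before the live timeline and never advance the persisted-position marker.

```python
class StreamIdGen:
    def __init__(self, start: int, step: int):
        self._current = start
        self._step = step  # +1 for the live stream, -1 for backfill

    def get_next_mult(self, n: int):
        ids = [self._current + self._step * (i + 1) for i in range(n)]
        self._current = ids[-1]
        return ids

live = StreamIdGen(0, +1)
backfill = StreamIdGen(0, -1)

live_ids = live.get_next_mult(3)          # [1, 2, 3]
backfill_ids = backfill.get_next_mult(2)  # [-1, -2]

# Mirrors the `if stream < 0` guard above: only non-negative orderings
# move the persisted-position marker.
persisted_position = max(s for s in live_ids + backfill_ids if s >= 0)
```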
315327 def _persist_events_txn(
316328 self,
317329 txn: LoggingTransaction,
330 *,
318331 events_and_contexts: List[Tuple[EventBase, EventContext]],
319 backfilled: bool,
332 inhibit_local_membership_updates: bool = False,
320333 state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
321334 new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
322335 ):
329342 Args:
330343 txn
331344 events_and_contexts: events to persist
332 backfilled: True if the events were backfilled
345 inhibit_local_membership_updates: Stop the local_current_membership
346 from being updated by these events. This should be set to True
347 for backfilled events because backfilled events in the past do
348 not affect the current local state.
333349 delete_existing True to purge existing table rows for the events
334350 from the database. This is useful when retrying due to
335351 IntegrityError.
362378 events_and_contexts
363379 )
364380
365 self._update_room_depths_txn(
366 txn, events_and_contexts=events_and_contexts, backfilled=backfilled
367 )
381 self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
368382
369383 # _update_outliers_txn filters out any events which have already been
370384 # persisted, and returns the filtered list.
397411 txn,
398412 events_and_contexts=events_and_contexts,
399413 all_events_and_contexts=all_events_and_contexts,
400 backfilled=backfilled,
414 inhibit_local_membership_updates=inhibit_local_membership_updates,
401415 )
402416
403417 # We call this last as it assumes we've inserted the events into
560574 # fetch their auth event info.
561575 while missing_auth_chains:
562576 sql = """
563 SELECT event_id, events.type, state_key, chain_id, sequence_number
577 SELECT event_id, events.type, se.state_key, chain_id, sequence_number
564578 FROM events
565 INNER JOIN state_events USING (event_id)
579 INNER JOIN state_events AS se USING (event_id)
566580 LEFT JOIN event_auth_chains USING (event_id)
567581 WHERE
568582 """
11991213 self,
12001214 txn,
12011215 events_and_contexts: List[Tuple[EventBase, EventContext]],
1202 backfilled: bool,
12031216 ):
12041217 """Update min_depth for each room
12051218
12071220 txn (twisted.enterprise.adbapi.Connection): db connection
12081221 events_and_contexts (list[(EventBase, EventContext)]): events
12091222 we are persisting
1210 backfilled (bool): True if the events were backfilled
12111223 """
12121224 depth_updates: Dict[str, int] = {}
12131225 for event, context in events_and_contexts:
12141226 # Remove any existing cache entries for the event_ids
12151227 txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
1216 if not backfilled:
1228 # Then update the `stream_ordering` position to mark the latest
1229 # event as the front of the room. This should not be done for
1230 # backfilled events because backfilled events have negative
1231 # stream_ordering and happened in the past so we know that we don't
1232 # need to update the stream_ordering tip/front for the room.
1233 assert event.internal_metadata.stream_ordering is not None
1234 if event.internal_metadata.stream_ordering >= 0:
12171235 txn.call_after(
12181236 self.store._events_stream_cache.entity_has_changed,
12191237 event.room_id,
14261444 return [ec for ec in events_and_contexts if ec[0] not in to_remove]
14271445
14281446 def _update_metadata_tables_txn(
1429 self, txn, events_and_contexts, all_events_and_contexts, backfilled
1447 self,
1448 txn,
1449 *,
1450 events_and_contexts,
1451 all_events_and_contexts,
1452 inhibit_local_membership_updates: bool = False,
14301453 ):
14311454 """Update all the miscellaneous tables for new events
14321455
14381461 events that we were going to persist. This includes events
14391462 we've already persisted, etc, that wouldn't appear in
14401463 events_and_context.
1441 backfilled (bool): True if the events were backfilled
1464 inhibit_local_membership_updates: Stop the local_current_membership
1465 from being updated by these events. This should be set to True
1466 for backfilled events because backfilled events in the past do
1467 not affect the current local state.
14421468 """
14431469
14441470 # Insert all the push actions into the event_push_actions table.
15121538 for event, _ in events_and_contexts
15131539 if event.type == EventTypes.Member
15141540 ],
1515 backfilled=backfilled,
1541 inhibit_local_membership_updates=inhibit_local_membership_updates,
15161542 )
15171543
15181544 # Insert event_reference_hashes table.
15521578 for row in rows:
15531579 event = ev_map[row["event_id"]]
15541580 if not row["rejects"] and not row["redacts"]:
1555 to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
1581 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
15561582
15571583 def prefill():
15581584 for cache_entry in to_prefill:
1559 self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
1585 self.store._get_event_cache.set(
1586 (cache_entry.event.event_id,), cache_entry
1587 )
15601588
15611589 txn.call_after(prefill)
15621590
16371665 txn, table="event_reference_hashes", values=vals
16381666 )
16391667
1640 def _store_room_members_txn(self, txn, events, backfilled):
1641 """Store a room member in the database."""
1668 def _store_room_members_txn(
1669 self, txn, events, *, inhibit_local_membership_updates: bool = False
1670 ):
1671 """
1672 Store room membership events in the database.
1673 Args:
1674 txn: The transaction to use.
1675 events: List of events to store.
1676 inhibit_local_membership_updates: Stop the local_current_membership
1677 from being updated by these events. This should be set to True
1678 for backfilled events because backfilled events in the past do
1679 not affect the current local state.
1680 """
16421681
16431682 def non_null_str_or_none(val: Any) -> Optional[str]:
16441683 return val if isinstance(val, str) and "\u0000" not in val else None
16811720 # band membership", like a remote invite or a rejection of a remote invite.
16821721 if (
16831722 self.is_mine_id(event.state_key)
1684 and not backfilled
1723 and not inhibit_local_membership_updates
16851724 and event.internal_metadata.is_outlier()
16861725 and event.internal_metadata.is_out_of_band_membership()
16871726 ):
1414 import logging
1515 import threading
1616 from typing import (
17 TYPE_CHECKING,
18 Any,
1719 Collection,
1820 Container,
1921 Dict,
2022 Iterable,
2123 List,
24 NoReturn,
2225 Optional,
2326 Set,
2427 Tuple,
28 cast,
2529 overload,
2630 )
2731
3741 from synapse.api.room_versions import (
3842 KNOWN_ROOM_VERSIONS,
3943 EventFormatVersions,
44 RoomVersion,
4045 RoomVersions,
4146 )
4247 from synapse.events import EventBase, make_event_from_dict
5560 from synapse.replication.tcp.streams import BackfillStream
5661 from synapse.replication.tcp.streams.events import EventsStream
5762 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
58 from synapse.storage.database import DatabasePool, LoggingTransaction
63 from synapse.storage.database import (
64 DatabasePool,
65 LoggingDatabaseConnection,
66 LoggingTransaction,
67 )
5968 from synapse.storage.engines import PostgresEngine
60 from synapse.storage.types import Connection
61 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
69 from synapse.storage.types import Cursor
70 from synapse.storage.util.id_generators import (
71 AbstractStreamIdTracker,
72 MultiWriterIdGenerator,
73 StreamIdGenerator,
74 )
6275 from synapse.storage.util.sequence import build_sequence_generator
6376 from synapse.types import JsonDict, get_domain_from_id
6477 from synapse.util import unwrapFirstError
6881 from synapse.util.iterutils import batch_iter
6982 from synapse.util.metrics import Measure
7083
84 if TYPE_CHECKING:
85 from synapse.server import HomeServer
86
7187 logger = logging.getLogger(__name__)
7288
7389
74 # These values are used in the `enqueus_event` and `_do_fetch` methods to
90 # These values are used in the `enqueue_event` and `_fetch_loop` methods to
7591 # control how we batch/bulk fetch events from the database.
7692 # The values are plucked out of thin air to make initial sync run faster
7793 # on jki.re
88104
89105
90106 @attr.s(slots=True, auto_attribs=True)
91 class _EventCacheEntry:
107 class EventCacheEntry:
92108 event: EventBase
93109 redacted_event: Optional[EventBase]
94110
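The namedtuple-to-attrs migration for the cache entry trades positional access for named attributes. Synapse uses `attr.s(slots=True, auto_attribs=True)`; the stdlib dataclass below is an illustrative equivalent, with `str` standing in for `EventBase`:

```python
from collections import namedtuple
from dataclasses import dataclass
from typing import Optional

# Named attributes make call sites like the cache prefill read as
# `cache_entry.event.event_id` instead of the opaque `cache_entry[0]`.

_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))

@dataclass(frozen=True)
class EventCacheEntry:
    event: str
    redacted_event: Optional[str]

old = _EventCacheEntry(event="$e1", redacted_event=None)
new = EventCacheEntry(event="$e1", redacted_event=None)
```

Renaming `_EventCacheEntry` to `EventCacheEntry` also lets the persistence store import it instead of defining a duplicate private copy.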
128144 json: str
129145 internal_metadata: str
130146 format_version: Optional[int]
131 room_version_id: Optional[int]
147 room_version_id: Optional[str]
132148 rejected_reason: Optional[str]
133149 redactions: List[str]
134150 outlier: bool
152168 # options controlling this.
153169 USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
154170
155 def __init__(self, database: DatabasePool, db_conn, hs):
171 def __init__(
172 self,
173 database: DatabasePool,
174 db_conn: LoggingDatabaseConnection,
175 hs: "HomeServer",
176 ):
156177 super().__init__(database, db_conn, hs)
157178
179 self._stream_id_gen: AbstractStreamIdTracker
180 self._backfill_id_gen: AbstractStreamIdTracker
158181 if isinstance(database.engine, PostgresEngine):
159182 # If we're using Postgres than we can use `MultiWriterIdGenerator`
160183 # regardless of whether this process writes to the streams or not.
213236 5 * 60 * 1000,
214237 )
215238
216 self._get_event_cache = LruCache(
239 self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
217240 cache_name="*getEvent*",
218241 max_size=hs.config.caches.event_cache_size,
219242 )
222245 # ID to cache entry. Note that the returned dict may not have the
223246 # requested event in it if the event isn't in the DB.
224247 self._current_event_fetches: Dict[
225 str, ObservableDeferred[Dict[str, _EventCacheEntry]]
248 str, ObservableDeferred[Dict[str, EventCacheEntry]]
226249 ] = {}
227250
228251 self._event_fetch_lock = threading.Condition()
229 self._event_fetch_list = []
252 self._event_fetch_list: List[
253 Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
254 ] = []
230255 self._event_fetch_ongoing = 0
231256 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
232257
233258 # We define this sequence here so that it can be referenced from both
234259 # the DataStore and PersistEventStore.
235 def get_chain_id_txn(txn):
260 def get_chain_id_txn(txn: Cursor) -> int:
236261 txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
237 return txn.fetchone()[0]
262 return cast(Tuple[int], txn.fetchone())[0]
238263
239264 self.event_chain_id_gen = build_sequence_generator(
240265 db_conn,
245270 id_column="chain_id",
246271 )
247272
248 def process_replication_rows(self, stream_name, instance_name, token, rows):
273 def process_replication_rows(
274 self,
275 stream_name: str,
276 instance_name: str,
277 token: int,
278 rows: Iterable[Any],
279 ) -> None:
249280 if stream_name == EventsStream.NAME:
250281 self._stream_id_gen.advance(instance_name, token)
251282 elif stream_name == BackfillStream.NAME:
279310 self,
280311 event_id: str,
281312 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
282 get_prev_content: bool = False,
283 allow_rejected: bool = False,
284 allow_none: Literal[False] = False,
285 check_room_id: Optional[str] = None,
313 get_prev_content: bool = ...,
314 allow_rejected: bool = ...,
315 allow_none: Literal[False] = ...,
316 check_room_id: Optional[str] = ...,
286317 ) -> EventBase:
287318 ...
288319
291322 self,
292323 event_id: str,
293324 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
294 get_prev_content: bool = False,
295 allow_rejected: bool = False,
296 allow_none: Literal[True] = False,
297 check_room_id: Optional[str] = None,
325 get_prev_content: bool = ...,
326 allow_rejected: bool = ...,
327 allow_none: Literal[True] = ...,
328 check_room_id: Optional[str] = ...,
298329 ) -> Optional[EventBase]:
299330 ...
300331
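The overload defaults change to `...` because in an `@overload` stub only the signature matters: `...` declares "this parameter has a default" without duplicating, or contradicting, the real one (the old `allow_none: Literal[True] = False` was self-contradictory). A toy function showing the pattern, not Synapse's `get_event`:

```python
from typing import Literal, Optional, overload

@overload
def get_event(event_id: str, allow_none: Literal[False] = ...) -> str: ...
@overload
def get_event(event_id: str, allow_none: Literal[True] = ...) -> Optional[str]: ...

def get_event(event_id: str, allow_none: bool = False) -> Optional[str]:
    # The real default lives here, on the implementation, in one place.
    events = {"$e1": "event one"}
    event = events.get(event_id)
    if event is None and not allow_none:
        raise KeyError(event_id)
    return event
```

The `Literal` overloads let the checker narrow the return type: `get_event(x)` is `str`, while `get_event(x, allow_none=True)` is `Optional[str]`.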
356387
357388 async def get_events(
358389 self,
359 event_ids: Iterable[str],
390 event_ids: Collection[str],
360391 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
361392 get_prev_content: bool = False,
362393 allow_rejected: bool = False,
543574
544575 async def _get_events_from_cache_or_db(
545576 self, event_ids: Iterable[str], allow_rejected: bool = False
546 ) -> Dict[str, _EventCacheEntry]:
577 ) -> Dict[str, EventCacheEntry]:
547578 """Fetch a bunch of events from the cache or the database.
548579
549580 If events are pulled from the database, they will be cached for future lookups.
577608 # same dict into itself N times).
578609 already_fetching_ids: Set[str] = set()
579610 already_fetching_deferreds: Set[
580 ObservableDeferred[Dict[str, _EventCacheEntry]]
611 ObservableDeferred[Dict[str, EventCacheEntry]]
581612 ] = set()
582613
583614 for event_id in missing_events_ids:
600631 # function returning more events than requested, but that can happen
601632 # already due to `_get_events_from_db`).
602633 fetching_deferred: ObservableDeferred[
603 Dict[str, _EventCacheEntry]
604 ] = ObservableDeferred(defer.Deferred())
634 Dict[str, EventCacheEntry]
635 ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
605636 for event_id in missing_events_ids:
606637 self._current_event_fetches[event_id] = fetching_deferred
607638
657688
658689 return event_entry_map
659690
660 def _invalidate_get_event_cache(self, event_id):
691 def _invalidate_get_event_cache(self, event_id: str) -> None:
661692 self._get_event_cache.invalidate((event_id,))
662693
663694 def _get_events_from_cache(
664695 self, events: Iterable[str], update_metrics: bool = True
665 ) -> Dict[str, _EventCacheEntry]:
696 ) -> Dict[str, EventCacheEntry]:
666697 """Fetch events from the caches.
667698
668699 May return rejected events.
735766 for e in state_to_include.values()
736767 ]
737768
738 def _do_fetch(self, conn: Connection) -> None:
769 def _maybe_start_fetch_thread(self) -> None:
770 """Starts an event fetch thread if we are not yet at the maximum number."""
771 with self._event_fetch_lock:
772 if (
773 self._event_fetch_list
774 and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
775 ):
776 self._event_fetch_ongoing += 1
777 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
778 # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
779 should_start = True
780 else:
781 should_start = False
782
783 if should_start:
784 run_as_background_process("fetch_events", self._fetch_thread)
785
786 async def _fetch_thread(self) -> None:
787 """Services requests for events from `_event_fetch_list`."""
788 exc = None
789 try:
790 await self.db_pool.runWithConnection(self._fetch_loop)
791 except BaseException as e:
792 exc = e
793 raise
794 finally:
795 should_restart = False
796 event_fetches_to_fail = []
797 with self._event_fetch_lock:
798 self._event_fetch_ongoing -= 1
799 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
800
801 # There may still be work remaining in `_event_fetch_list` if we
802 # failed, or it was added in between us deciding to exit and
803 # decrementing `_event_fetch_ongoing`.
804 if self._event_fetch_list:
805 if exc is None:
806 # We decided to exit, but then some more work was added
807 # before `_event_fetch_ongoing` was decremented.
808 # If a new event fetch thread was not started, we should
809 # restart ourselves since the remaining event fetch threads
810 # may take a while to get around to the new work.
811 #
812 # Unfortunately it is not possible to tell whether a new
813 # event fetch thread was started, so we restart
814 # unconditionally. If we are unlucky, we will end up with
815 # an idle fetch thread, but it will time out after
816 # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
817 # in any case.
818 #
819 # Note that multiple fetch threads may run down this path at
820 # the same time.
821 should_restart = True
822 elif isinstance(exc, Exception):
823 if self._event_fetch_ongoing == 0:
824 # We were the last remaining fetcher and failed.
825 # Fail any outstanding fetches since no one else will
826 # handle them.
827 event_fetches_to_fail = self._event_fetch_list
828 self._event_fetch_list = []
829 else:
830 # We weren't the last remaining fetcher, so another
831 # fetcher will pick up the work. This will either happen
832 # after their existing work, however long that takes,
833 # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
834 # they are idle.
835 pass
836 else:
837 # The exception is a `SystemExit`, `KeyboardInterrupt` or
838 # `GeneratorExit`. Don't try to do anything clever here.
839 pass
840
841 if should_restart:
842 # We exited cleanly but noticed more work.
843 self._maybe_start_fetch_thread()
844
845 if event_fetches_to_fail:
846 # We were the last remaining fetcher and failed.
847 # Fail any outstanding fetches since no one else will handle them.
848 assert exc is not None
849 with PreserveLoggingContext():
850 for _, deferred in event_fetches_to_fail:
851 deferred.errback(exc)
852
853 def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
739854 """Takes a database connection and waits for requests for events from
740855 the _event_fetch_list queue.
741856 """
742 try:
743 i = 0
744 while True:
745 with self._event_fetch_lock:
746 event_list = self._event_fetch_list
747 self._event_fetch_list = []
748
749 if not event_list:
750 single_threaded = self.database_engine.single_threaded
751 if (
752 not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
753 or single_threaded
754 or i > EVENT_QUEUE_ITERATIONS
755 ):
756 break
757 else:
758 self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
759 i += 1
760 continue
761 i = 0
762
763 self._fetch_event_list(conn, event_list)
764 finally:
765 self._event_fetch_ongoing -= 1
766 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
857 i = 0
858 while True:
859 with self._event_fetch_lock:
860 event_list = self._event_fetch_list
861 self._event_fetch_list = []
862
863 if not event_list:
864 # There are no requests waiting. If we haven't yet reached the
865 # maximum iteration limit, wait for some more requests to turn up.
866 # Otherwise, bail out.
867 single_threaded = self.database_engine.single_threaded
868 if (
869 not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
870 or single_threaded
871 or i > EVENT_QUEUE_ITERATIONS
872 ):
873 return
874
875 self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
876 i += 1
877 continue
878 i = 0
879
880 self._fetch_event_list(conn, event_list)
767881
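The restructured `_maybe_start_fetch_thread`/`_fetch_thread` pair closes a race: work can be queued between a fetcher deciding to exit and it decrementing the ongoing counter. A simplified sketch of that accounting (names and structure are illustrative):

```python
import threading

MAX_FETCHERS = 2  # stands in for EVENT_QUEUE_THREADS

class FetchPool:
    def __init__(self):
        self._lock = threading.Lock()
        self._queue = []     # pending fetch requests
        self._ongoing = 0    # running fetcher threads
        self.started = 0     # how many times we "launched" a fetcher

    def maybe_start(self):
        # The counter is only touched under the lock.
        with self._lock:
            should_start = bool(self._queue) and self._ongoing < MAX_FETCHERS
            if should_start:
                self._ongoing += 1
        if should_start:
            self.started += 1  # stand-in for run_as_background_process(...)

    def fetcher_exits(self):
        with self._lock:
            self._ongoing -= 1
            work_remains = bool(self._queue)
        if work_remains:
            # Work arrived between deciding to exit and decrementing the
            # counter; restart unconditionally rather than risk stranding it.
            self.maybe_start()

pool = FetchPool()
pool._queue.append("fetch $e1")
pool.maybe_start()  # launches the first fetcher
```

As the comments in the diff note, the unconditional restart can leave an idle fetcher, but it times out after `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds, which is cheaper than stalling a pending request.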
768882 def _fetch_event_list(
769 self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
883 self,
884 conn: LoggingDatabaseConnection,
885 event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
770886 ) -> None:
771887 """Handle a load of requests from the _event_fetch_list queue
772888
793909 )
794910
795911 # We only want to resolve deferreds from the main thread
796 def fire():
912 def fire() -> None:
797913 for _, d in event_list:
798914 d.callback(row_dict)
799915
803919 logger.exception("do_fetch")
804920
805921 # We only want to resolve deferreds from the main thread
806 def fire(evs, exc):
807 for _, d in evs:
808 if not d.called:
809 with PreserveLoggingContext():
810 d.errback(exc)
922 def fire_errback(exc: Exception) -> None:
923 for _, d in event_list:
924 d.errback(exc)
811925
812926 with PreserveLoggingContext():
813 self.hs.get_reactor().callFromThread(fire, event_list, e)
927 self.hs.get_reactor().callFromThread(fire_errback, e)
814928
815929 async def _get_events_from_db(
816 self, event_ids: Iterable[str]
817 ) -> Dict[str, _EventCacheEntry]:
930 self, event_ids: Collection[str]
931 ) -> Dict[str, EventCacheEntry]:
818932 """Fetch a bunch of events from the database.
819933
820934 May return rejected events.
830944 map from event id to result. May return extra events which
831945 weren't asked for.
832946 """
833 fetched_events = {}
947 fetched_event_ids: Set[str] = set()
948 fetched_events: Dict[str, _EventRow] = {}
834949 events_to_fetch = event_ids
835950
836951 while events_to_fetch:
837952 row_map = await self._enqueue_events(events_to_fetch)
838953
839954 # we need to recursively fetch any redactions of those events
840 redaction_ids = set()
955 redaction_ids: Set[str] = set()
841956 for event_id in events_to_fetch:
842957 row = row_map.get(event_id)
843 fetched_events[event_id] = row
958 fetched_event_ids.add(event_id)
844959 if row:
960 fetched_events[event_id] = row
845961 redaction_ids.update(row.redactions)
846962
847 events_to_fetch = redaction_ids.difference(fetched_events.keys())
963 events_to_fetch = redaction_ids.difference(fetched_event_ids)
848964 if events_to_fetch:
849965 logger.debug("Also fetching redaction events %s", events_to_fetch)
850966
851967 # build a map from event_id to EventBase
852 event_map = {}
968 event_map: Dict[str, EventBase] = {}
853969 for event_id, row in fetched_events.items():
854 if not row:
855 continue
856970 assert row.event_id == event_id
857971
858972 rejected_reason = row.rejected_reason
880994
881995 room_version_id = row.room_version_id
882996
997 room_version: Optional[RoomVersion]
883998 if not room_version_id:
884999 # this should only happen for out-of-band membership events which
8851000 # arrived before #6983 landed. For all other events, we should have
9501065
9511066 # finally, we can decide whether each one needs redacting, and build
9521067 # the cache entries.
953 result_map = {}
1068 result_map: Dict[str, EventCacheEntry] = {}
9541069 for event_id, original_ev in event_map.items():
9551070 redactions = fetched_events[event_id].redactions
9561071 redacted_event = self._maybe_redact_event_row(
9571072 original_ev, redactions, event_map
9581073 )
9591074
960 cache_entry = _EventCacheEntry(
1075 cache_entry = EventCacheEntry(
9611076 event=original_ev, redacted_event=redacted_event
9621077 )
9631078
9661081
9671082 return result_map
9681083
969 async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
1084 async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
9701085 """Fetches events from the database using the _event_fetch_list. This
9711086 allows batch and bulk fetching of events - it allows us to fetch events
9721087 without having to create a new transaction for each request for events.
9791094 that weren't requested.
9801095 """
9811096
982 events_d = defer.Deferred()
1097 events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
9831098 with self._event_fetch_lock:
9841099 self._event_fetch_list.append((events, events_d))
985
9861100 self._event_fetch_lock.notify()
9871101
988 if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
989 self._event_fetch_ongoing += 1
990 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
991 should_start = True
992 else:
993 should_start = False
994
995 if should_start:
996 run_as_background_process(
997 "fetch_events", self.db_pool.runWithConnection, self._do_fetch
998 )
1102 self._maybe_start_fetch_thread()
9991103
10001104 logger.debug("Loading %d events: %s", len(events), events)
10011105 with PreserveLoggingContext():
11451249 # no valid redaction found for this event
11461250 return None
11471251
1148 async def have_events_in_timeline(self, event_ids):
1252 async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
11491253 """Given a list of event ids, check if we have already processed and
11501254 stored them as non outliers.
11511255 """
11741278 event_ids: events we are looking for
11751279
11761280 Returns:
1177 set[str]: The events we have already seen.
1281 The set of events we have already seen.
11781282 """
11791283 res = await self._have_seen_events_dict(
11801284 (room_id, event_id) for event_id in event_ids
11971301 }
11981302 results = {x: True for x in cache_results}
11991303
1200 def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
1304 def have_seen_events_txn(
1305 txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
1306 ) -> None:
12011307 # we deliberately do *not* query the database for room_id, to make the
12021308 # query an index-only lookup on `events_event_id_key`.
12031309 #
12231329 return results
12241330
12251331 @cached(max_entries=100000, tree=True)
1226 async def have_seen_event(self, room_id: str, event_id: str):
1332 async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
12271333 # this only exists for the benefit of the @cachedList descriptor on
12281334 # _have_seen_events_dict
12291335 raise NotImplementedError()
12301336
1231 def _get_current_state_event_counts_txn(self, txn, room_id):
1337 def _get_current_state_event_counts_txn(
1338 self, txn: LoggingTransaction, room_id: str
1339 ) -> int:
12321340 """
12331341 See get_current_state_event_counts.
12341342 """
12531361 room_id,
12541362 )
12551363
1256 async def get_room_complexity(self, room_id):
1364 async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
12571365 """
12581366 Get a rough approximation of the complexity of the room. This is used by
12591367 remote servers to decide whether they wish to join the room or not.
12611369 more resources.
12621370
12631371 Args:
1264 room_id (str)
1265
1266 Returns:
1267 dict[str:int] of complexity version to complexity.
1372 room_id: The room ID to query.
1373
1374 Returns:
1375 dict[str:float] of complexity version to complexity.
12681376 """
12691377 state_events = await self.get_current_state_event_counts(room_id)
12701378
12741382
12751383 return {"v1": complexity_v1}
12761384
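The line computing `complexity_v1` from the state-event count falls outside this hunk; as a rough, hedged sketch of the "v1" metric (the divisor of 500 is an assumption about the elided computation, not confirmed by this excerpt):

```python
def room_complexity_v1(state_event_count: int) -> dict:
    # Rough sketch: scale the current state event count down so that a
    # typical room scores well under 1.0.  The divisor of 500 is an
    # assumption, not taken from the diff above.
    complexity_v1 = round(state_event_count / 500, 2)
    return {"v1": complexity_v1}
```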
1277 def get_current_events_token(self):
1385 def get_current_events_token(self) -> int:
12781386 """The current maximum token that events have reached"""
12791387 return self._stream_id_gen.get_current_token()
12801388
12811389 async def get_all_new_forward_event_rows(
12821390 self, instance_name: str, last_id: int, current_id: int, limit: int
1283 ) -> List[Tuple]:
1391 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
12841392 """Returns new events, for the Events replication stream
12851393
12861394 Args:
12941402 EventsStreamRow.
12951403 """
12961404
1297 def get_all_new_forward_event_rows(txn):
1405 def get_all_new_forward_event_rows(
1406 txn: LoggingTransaction,
1407 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
12981408 sql = (
12991409 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
1300 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
1410 " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
13011411 " FROM events AS e"
13021412 " LEFT JOIN redactions USING (event_id)"
1303 " LEFT JOIN state_events USING (event_id)"
1413 " LEFT JOIN state_events AS se USING (event_id)"
13041414 " LEFT JOIN event_relations USING (event_id)"
13051415 " LEFT JOIN room_memberships USING (event_id)"
13061416 " LEFT JOIN rejections USING (event_id)"
13101420 " LIMIT ?"
13111421 )
13121422 txn.execute(sql, (last_id, current_id, instance_name, limit))
1313 return txn.fetchall()
1423 return cast(
1424 List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
1425 )
13141426
13151427 return await self.db_pool.runInteraction(
13161428 "get_all_new_forward_event_rows", get_all_new_forward_event_rows
13181430
13191431 async def get_ex_outlier_stream_rows(
13201432 self, instance_name: str, last_id: int, current_id: int
1321 ) -> List[Tuple]:
1433 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
13221434 """Returns de-outliered events, for the Events replication stream
13231435
13241436 Args:
13311443 EventsStreamRow.
13321444 """
13331445
1334 def get_ex_outlier_stream_rows_txn(txn):
1446 def get_ex_outlier_stream_rows_txn(
1447 txn: LoggingTransaction,
1448 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
13351449 sql = (
13361450 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
1337 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
1451 " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
13381452 " FROM events AS e"
13391453 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
13401454 " LEFT JOIN redactions USING (event_id)"
1341 " LEFT JOIN state_events USING (event_id)"
1455 " LEFT JOIN state_events AS se USING (event_id)"
13421456 " LEFT JOIN event_relations USING (event_id)"
13431457 " LEFT JOIN room_memberships USING (event_id)"
13441458 " LEFT JOIN rejections USING (event_id)"
13491463 )
13501464
13511465 txn.execute(sql, (last_id, current_id, instance_name))
1352 return txn.fetchall()
1466 return cast(
1467 List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
1468 )
13531469
13541470 return await self.db_pool.runInteraction(
13551471 "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
13571473
13581474 async def get_all_new_backfill_event_rows(
13591475 self, instance_name: str, last_id: int, current_id: int, limit: int
1360 ) -> Tuple[List[Tuple[int, list]], int, bool]:
1476 ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
13611477 """Get updates for backfill replication stream, including all new
13621478 backfilled events and events that have gone from being outliers to not.
13631479
13851501 if last_id == current_id:
13861502 return [], current_id, False
13871503
1388 def get_all_new_backfill_event_rows(txn):
1504 def get_all_new_backfill_event_rows(
1505 txn: LoggingTransaction,
1506 ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
13891507 sql = (
13901508 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
1391 " state_key, redacts, relates_to_id"
1509 " se.state_key, redacts, relates_to_id"
13921510 " FROM events AS e"
13931511 " LEFT JOIN redactions USING (event_id)"
1394 " LEFT JOIN state_events USING (event_id)"
1512 " LEFT JOIN state_events AS se USING (event_id)"
13951513 " LEFT JOIN event_relations USING (event_id)"
13961514 " WHERE ? > stream_ordering AND stream_ordering >= ?"
13971515 " AND instance_name = ?"
13991517 " LIMIT ?"
14001518 )
14011519 txn.execute(sql, (-last_id, -current_id, instance_name, limit))
1402 new_event_updates = [(row[0], row[1:]) for row in txn]
1520 new_event_updates: List[
1521 Tuple[int, Tuple[str, str, str, str, str, str]]
1522 ] = []
1523 row: Tuple[int, str, str, str, str, str, str]
1524 # Type safety: iterating over `txn` yields `Tuple`, i.e.
1525 # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
1526 # variadic tuple to a fixed length tuple and flags it up as an error.
1527 for row in txn: # type: ignore[assignment]
1528 new_event_updates.append((row[0], row[1:]))
14031529
14041530 limited = False
14051531 if len(new_event_updates) == limit:
14101536
14111537 sql = (
14121538 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
1413 " state_key, redacts, relates_to_id"
1539 " se.state_key, redacts, relates_to_id"
14141540 " FROM events AS e"
14151541 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
14161542 " LEFT JOIN redactions USING (event_id)"
1417 " LEFT JOIN state_events USING (event_id)"
1543 " LEFT JOIN state_events AS se USING (event_id)"
14181544 " LEFT JOIN event_relations USING (event_id)"
14191545 " WHERE ? > event_stream_ordering"
14201546 " AND event_stream_ordering >= ?"
14221548 " ORDER BY event_stream_ordering DESC"
14231549 )
14241550 txn.execute(sql, (-last_id, -upper_bound, instance_name))
1425 new_event_updates.extend((row[0], row[1:]) for row in txn)
1551 # Type safety: iterating over `txn` yields `Tuple`, i.e.
1552 # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
1553 # variadic tuple to a fixed length tuple and flags it up as an error.
1554 for row in txn: # type: ignore[assignment]
1555 new_event_updates.append((row[0], row[1:]))
14261556
14271557 if len(new_event_updates) >= limit:
14281558 upper_bound = new_event_updates[-1][0]
14361566
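The `# type: ignore[assignment]` comments above work around mypy treating DB-API rows as variadic `Tuple[Any, ...]` values. A minimal self-contained illustration of the same narrowing, using `cast` as the other transaction helpers in this diff do:

```python
import sqlite3
from typing import List, Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (stream_ordering INTEGER, event_id TEXT)")
conn.executemany("INSERT INTO events VALUES (?, ?)", [(1, "$a"), (2, "$b")])

cur = conn.execute(
    "SELECT stream_ordering, event_id FROM events ORDER BY stream_ordering"
)
# DB-API cursors yield tuples of arbitrary length, so assigning a row to a
# fixed-length tuple type trips mypy; `cast` (or a targeted
# `# type: ignore[assignment]`) records the shape the query is known to return.
rows = cast(List[Tuple[int, str]], cur.fetchall())
updates = [(row[0], row[1:]) for row in rows]
```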
14371567 async def get_all_updated_current_state_deltas(
14381568 self, instance_name: str, from_token: int, to_token: int, target_row_count: int
1439 ) -> Tuple[List[Tuple], int, bool]:
1569 ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
14401570 """Fetch updates from current_state_delta_stream
14411571
14421572 Args:
14561586 * `limited` is whether there are more updates to fetch.
14571587 """
14581588
1459 def get_all_updated_current_state_deltas_txn(txn):
1589 def get_all_updated_current_state_deltas_txn(
1590 txn: LoggingTransaction,
1591 ) -> List[Tuple[int, str, str, str, str]]:
14601592 sql = """
14611593 SELECT stream_id, room_id, type, state_key, event_id
14621594 FROM current_state_delta_stream
14651597 ORDER BY stream_id ASC LIMIT ?
14661598 """
14671599 txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
1468 return txn.fetchall()
1469
1470 def get_deltas_for_stream_id_txn(txn, stream_id):
1600 return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
1601
1602 def get_deltas_for_stream_id_txn(
1603 txn: LoggingTransaction, stream_id: int
1604 ) -> List[Tuple[int, str, str, str, str]]:
14711605 sql = """
14721606 SELECT stream_id, room_id, type, state_key, event_id
14731607 FROM current_state_delta_stream
14741608 WHERE stream_id = ?
14751609 """
14761610 txn.execute(sql, [stream_id])
1477 return txn.fetchall()
1611 return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
14781612
14791613 # we need to make sure that, for every stream id in the results, we get *all*
14801614 # the rows with that stream id.
14811615
1482 rows: List[Tuple] = await self.db_pool.runInteraction(
1616 rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
14831617 "get_all_updated_current_state_deltas",
14841618 get_all_updated_current_state_deltas_txn,
14851619 )
15081642
15091643 return rows, to_token, True
15101644
1511 async def is_event_after(self, event_id1, event_id2):
1645 async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
15121646 """Returns True if event_id1 is after event_id2 in the stream"""
15131647 to_1, so_1 = await self.get_event_ordering(event_id1)
15141648 to_2, so_2 = await self.get_event_ordering(event_id2)
15151649 return (to_1, so_1) > (to_2, so_2)
15161650
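`is_event_after` leans on Python's lexicographic tuple comparison: the topological ordering is compared first, with the stream ordering only breaking ties. A standalone illustration of that comparison:

```python
from typing import Tuple

def is_event_after(ordering1: Tuple[int, int], ordering2: Tuple[int, int]) -> bool:
    # (topological_ordering, stream_ordering) pairs compare
    # lexicographically: the topological component dominates, the stream
    # component is only consulted when the topological parts are equal.
    return ordering1 > ordering2
```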
15171651 @cached(max_entries=5000)
1518 async def get_event_ordering(self, event_id):
1652 async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
15191653 res = await self.db_pool.simple_select_one(
15201654 table="events",
15211655 retcols=["topological_ordering", "stream_ordering"],
15381672 None otherwise.
15391673 """
15401674
1541 def get_next_event_to_expire_txn(txn):
1675 def get_next_event_to_expire_txn(
1676 txn: LoggingTransaction,
1677 ) -> Optional[Tuple[str, int]]:
15421678 txn.execute(
15431679 """
15441680 SELECT event_id, expiry_ts FROM event_expiry
15461682 """
15471683 )
15481684
1549 return txn.fetchone()
1685 return cast(Optional[Tuple[str, int]], txn.fetchone())
15501686
15511687 return await self.db_pool.runInteraction(
15521688 desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
16101746 return mapping
16111747
16121748 @wrap_as_background_process("_cleanup_old_transaction_ids")
1613 async def _cleanup_old_transaction_ids(self):
1749 async def _cleanup_old_transaction_ids(self) -> None:
16141750 """Cleans out transaction id mappings older than 24hrs."""
16151751
1616 def _cleanup_old_transaction_ids_txn(txn):
1752 def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
16171753 sql = """
16181754 DELETE FROM event_txn_id
16191755 WHERE inserted_ts < ?
16251761 "_cleanup_old_transaction_ids",
16261762 _cleanup_old_transaction_ids_txn,
16271763 )
1764
1765 async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
1766 """Check if the given event is next to a backward gap of missing events.
1767 <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
1768
1769 Args:
 1770 event: The event to check
 1771 for a backward gap behind it
1772
1773 Returns:
 1774 Boolean indicating whether the event is next to a backward gap
1775 """
1776
1777 def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
1778 # If the event in question has any of its prev_events listed as a
1779 # backward extremity, it's next to a gap.
1780 #
1781 # We can't just check the backward edges in `event_edges` because
1782 # when we persist events, we will also record the prev_events as
1783 # edges to the event in question regardless of whether we have those
1784 # prev_events yet. We need to check whether those prev_events are
1785 # backward extremities, also known as gaps, that need to be
1786 # backfilled.
1787 backward_extremity_query = """
1788 SELECT 1 FROM event_backward_extremities
1789 WHERE
1790 room_id = ?
1791 AND %s
1792 LIMIT 1
1793 """
1794
1795 # If the event in question is a backward extremity or has any of its
1796 # prev_events listed as a backward extremity, it's next to a
1797 # backward gap.
1798 clause, args = make_in_list_sql_clause(
1799 self.database_engine,
1800 "event_id",
1801 [event.event_id] + list(event.prev_event_ids()),
1802 )
1803
1804 txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
1805 backward_extremities = txn.fetchall()
1806
1807 # We consider any backward extremity as a backward gap
1808 if len(backward_extremities):
1809 return True
1810
1811 return False
1812
1813 return await self.db_pool.runInteraction(
1814 "is_event_next_to_backward_gap_txn",
1815 is_event_next_to_backward_gap_txn,
1816 )
1817
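`make_in_list_sql_clause` above expands a list of event IDs into a parameterised `IN (...)` clause. A simplified sketch of what the helper produces (the real Synapse version also takes the database engine and special-cases Postgres array lookups):

```python
from typing import List, Tuple

def make_in_list_sql_clause(column: str, values: List[str]) -> Tuple[str, List[str]]:
    # Simplified sketch: one "?" placeholder per value, so the IDs are
    # bound as query parameters rather than interpolated into the SQL.
    placeholders = ", ".join("?" for _ in values)
    return f"{column} IN ({placeholders})", list(values)
```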
1818 async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
1819 """Check if the given event is next to a forward gap of missing events.
1820 The gap in front of the latest events is not considered a gap.
1821 <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
1822 <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
1823
1824 Args:
 1825 event: The event to check
 1826 for a forward gap ahead of it
1827
1828 Returns:
 1829 Boolean indicating whether the event is next to a forward gap
1830 """
1831
1832 def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
1833 # If the event in question is a forward extremity, we will just
1834 # consider any potential forward gap as not a gap since it's one of
1835 # the latest events in the room.
1836 #
1837 # `event_forward_extremities` does not include backfilled or outlier
1838 # events so we can't rely on it to find forward gaps. We can only
1839 # use it to determine whether a message is the latest in the room.
1840 #
1841 # We can't combine this query with the `forward_edge_query` below
1842 # because if the event in question has no forward edges (isn't
1843 # referenced by any other event's prev_events) but is in
1844 # `event_forward_extremities`, we don't want to return 0 rows and
1845 # say it's next to a gap.
1846 forward_extremity_query = """
1847 SELECT 1 FROM event_forward_extremities
1848 WHERE
1849 room_id = ?
1850 AND event_id = ?
1851 LIMIT 1
1852 """
1853
1854 # Check to see whether the event in question is already referenced
1855 # by another event. If we don't see any edges, we're next to a
1856 # forward gap.
1857 forward_edge_query = """
1858 SELECT 1 FROM event_edges
1859 /* Check to make sure the event referencing our event in question is not rejected */
 1860 LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
1861 WHERE
1862 event_edges.room_id = ?
1863 AND event_edges.prev_event_id = ?
1864 /* It's not a valid edge if the event referencing our event in
1865 * question is rejected.
1866 */
1867 AND rejections.event_id IS NULL
1868 LIMIT 1
1869 """
1870
1871 # We consider any forward extremity as the latest in the room and
1872 # not a forward gap.
1873 #
1874 # To expand, even though there is technically a gap at the front of
1875 # the room where the forward extremities are, we consider those the
1876 # latest messages in the room so asking other homeservers for more
1877 # is useless. The new latest messages will just be federated as
1878 # usual.
1879 txn.execute(forward_extremity_query, (event.room_id, event.event_id))
1880 forward_extremities = txn.fetchall()
1881 if len(forward_extremities):
1882 return False
1883
1884 # If there are no forward edges to the event in question (another
1885 # event hasn't referenced this event in their prev_events), then we
1886 # assume there is a forward gap in the history.
1887 txn.execute(forward_edge_query, (event.room_id, event.event_id))
1888 forward_edges = txn.fetchall()
1889 if not len(forward_edges):
1890 return True
1891
1892 return False
1893
1894 return await self.db_pool.runInteraction(
1895 "is_event_next_to_gap_txn",
1896 is_event_next_to_gap_txn,
1897 )
1898
1899 async def get_event_id_for_timestamp(
1900 self, room_id: str, timestamp: int, direction: str
1901 ) -> Optional[str]:
1902 """Find the closest event to the given timestamp in the given direction.
1903
1904 Args:
1905 room_id: Room to fetch the event from
1906 timestamp: The point in time (inclusive) we should navigate from in
1907 the given direction to find the closest event.
1908 direction: ["f"|"b"] to indicate whether we should navigate forward
1909 or backward from the given timestamp to find the closest event.
1910
1911 Returns:
1912 The closest event_id otherwise None if we can't find any event in
1913 the given direction.
1914 """
1915
1916 sql_template = """
1917 SELECT event_id FROM events
1918 LEFT JOIN rejections USING (event_id)
1919 WHERE
1920 origin_server_ts %s ?
1921 AND room_id = ?
1922 /* Make sure event is not rejected */
1923 AND rejections.event_id IS NULL
1924 ORDER BY origin_server_ts %s
1925 LIMIT 1;
1926 """
1927
1928 def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
1929 if direction == "b":
1930 # Find closest event *before* a given timestamp. We use descending
1931 # (which gives values largest to smallest) because we want the
1932 # largest possible timestamp *before* the given timestamp.
1933 comparison_operator = "<="
1934 order = "DESC"
1935 else:
1936 # Find closest event *after* a given timestamp. We use ascending
1937 # (which gives values smallest to largest) because we want the
1938 # closest possible timestamp *after* the given timestamp.
1939 comparison_operator = ">="
1940 order = "ASC"
1941
1942 txn.execute(
1943 sql_template % (comparison_operator, order), (timestamp, room_id)
1944 )
1945 row = txn.fetchone()
1946 if row:
1947 (event_id,) = row
1948 return event_id
1949
1950 return None
1951
1952 if direction not in ("f", "b"):
1953 raise ValueError("Unknown direction: %s" % (direction,))
1954
1955 return await self.db_pool.runInteraction(
1956 "get_event_id_for_timestamp_txn",
1957 get_event_id_for_timestamp_txn,
1958 )
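The direction handling above can be mimicked in plain Python: for `"b"` pick the largest `origin_server_ts` at or before the target, for `"f"` the smallest at or after it. A sketch over an in-memory event list:

```python
from typing import List, Optional, Tuple

def closest_event_id(
    events: List[Tuple[str, int]], timestamp: int, direction: str
) -> Optional[str]:
    """events are (event_id, origin_server_ts) pairs."""
    if direction == "b":
        # Closest event at or *before* the timestamp: largest ts <= target.
        candidates = [e for e in events if e[1] <= timestamp]
        pick = max
    elif direction == "f":
        # Closest event at or *after* the timestamp: smallest ts >= target.
        candidates = [e for e in events if e[1] >= timestamp]
        pick = min
    else:
        raise ValueError("Unknown direction: %s" % (direction,))
    if not candidates:
        return None
    return pick(candidates, key=lambda e: e[1])[0]
```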
117117
118118 logger.info("[purge] looking for events to delete")
119119
120 should_delete_expr = "state_key IS NULL"
120 should_delete_expr = "state_events.state_key IS NULL"
121121 should_delete_params: Tuple[Any, ...] = ()
122122 if not delete_local_events:
123123 should_delete_expr += " AND event_id NOT LIKE ?"
2727 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
2828 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
2929 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
30 from synapse.storage.util.id_generators import StreamIdGenerator
30 from synapse.storage.util.id_generators import (
31 AbstractStreamIdTracker,
32 StreamIdGenerator,
33 )
3134 from synapse.util import json_encoder
3235 from synapse.util.caches.descriptors import cached, cachedList
3336 from synapse.util.caches.stream_change_cache import StreamChangeCache
8184 super().__init__(database, db_conn, hs)
8285
8386 if hs.config.worker.worker_app is None:
84 self._push_rules_stream_id_gen: Union[
85 StreamIdGenerator, SlavedIdTracker
86 ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
87 self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
88 db_conn, "push_rules_stream", "stream_id"
89 )
8790 else:
8891 self._push_rules_stream_id_gen = SlavedIdTracker(
8992 db_conn, "push_rules_stream", "stream_id"
104104
105105 has_next_access_token_been_used: bool
106106 """True if the next access token was already used at least once."""
107
108 expiry_ts: Optional[int]
109 """The time at which the refresh token expires and can not be used.
110 If None, the refresh token doesn't expire."""
111
112 ultimate_session_expiry_ts: Optional[int]
113 """The time at which the session comes to an end and can no longer be
114 refreshed.
115 If None, the session can be refreshed indefinitely."""
107116
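The two new nullable timestamps bound a session in different ways: `expiry_ts` limits this particular refresh token, while `ultimate_session_expiry_ts` caps the whole chain of refreshes. A hypothetical validity check (the helper name is illustrative, not a Synapse API):

```python
from typing import Optional

def refresh_token_usable(
    now_ms: int,
    expiry_ts: Optional[int],
    ultimate_session_expiry_ts: Optional[int],
) -> bool:
    # A None in either column means "no limit" for that bound.
    if expiry_ts is not None and now_ms >= expiry_ts:
        return False
    if ultimate_session_expiry_ts is not None and now_ms >= ultimate_session_expiry_ts:
        return False
    return True
```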
108117
109118 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
16251634 rt.user_id,
16261635 rt.device_id,
16271636 rt.next_token_id,
1628 (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
1629 at.used has_next_access_token_been_used
1637 (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
1638 at.used AS has_next_access_token_been_used,
1639 rt.expiry_ts,
1640 rt.ultimate_session_expiry_ts
16301641 FROM refresh_tokens rt
16311642 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
16321643 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
16471658 has_next_refresh_token_been_refreshed=row[4],
16481659 # This column is nullable, ensure it's a boolean
16491660 has_next_access_token_been_used=(row[5] or False),
1661 expiry_ts=row[6],
1662 ultimate_session_expiry_ts=row[7],
16501663 )
16511664
16521665 return await self.db_pool.runInteraction(
19141927 user_id: str,
19151928 token: str,
19161929 device_id: Optional[str],
1930 expiry_ts: Optional[int],
1931 ultimate_session_expiry_ts: Optional[int],
19171932 ) -> int:
19181933 """Adds a refresh token for the given user.
19191934
19211936 user_id: The user ID.
19221937 token: The new access token to add.
19231938 device_id: ID of the device to associate with the refresh token.
1939 expiry_ts (milliseconds since the epoch): Time after which the
1940 refresh token cannot be used.
 1941 If None, the refresh token has no expiry and remains valid until used.
1942 ultimate_session_expiry_ts (milliseconds since the epoch):
1943 Time at which the session will end and can not be extended any
1944 further.
1945 If None, the session can be refreshed indefinitely.
19241946 Raises:
19251947 StoreError if there was a problem adding this.
19261948 Returns:
19361958 "device_id": device_id,
19371959 "token": token,
19381960 "next_token_id": None,
1961 "expiry_ts": expiry_ts,
1962 "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
19391963 },
19401964 desc="add_refresh_token_to_user",
19411965 )
475475 INNER JOIN events AS e USING (room_id, event_id)
476476 WHERE
477477 c.type = 'm.room.member'
478 AND state_key = ?
478 AND c.state_key = ?
479479 AND c.membership = ?
480480 """
481481 else:
486486 INNER JOIN events AS e USING (room_id, event_id)
487487 WHERE
488488 c.type = 'm.room.member'
489 AND state_key = ?
489 AND c.state_key = ?
490490 AND m.membership = ?
491491 """
492492
496496 oldest `limit` events.
497497
498498 Returns:
499 The list of events (in ascending order) and the token from the start
499 The list of events (in ascending stream order) and the token from the start
500500 of the chunk of events returned.
501501 """
502502 if from_key == to_key:
509509 if not has_changed:
510510 return [], from_key
511511
512 def f(txn):
512 def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
513513 # To handle tokens with a non-empty instance_map we fetch more
514514 # results than necessary and then filter down
515515 min_from_id = from_key.stream
564564 async def get_membership_changes_for_user(
565565 self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
566566 ) -> List[EventBase]:
567 """Fetch membership events for a given user.
568
569 All such events whose stream ordering `s` lies in the range
570 `from_key < s <= to_key` are returned. Events are ordered by ascending stream
571 order.
572 """
573 # Start by ruling out cases where a DB query is not necessary.
567574 if from_key == to_key:
568575 return []
569576
574581 if not has_changed:
575582 return []
576583
577 def f(txn):
584 def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
578585 # To handle tokens with a non-empty instance_map we fetch more
579586 # results than necessary and then filter down
580587 min_from_id = from_key.stream
633640
634641 Returns:
635642 A list of events and a token pointing to the start of the returned
636 events. The events returned are in ascending order.
643 events. The events returned are in ascending topological order.
637644 """
638645
639646 rows, token = await self.get_recent_event_ids_for_room(
1313
1414 import logging
1515 from collections import namedtuple
16 from enum import Enum
1617 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
1718
1819 import attr
4142 _UpdateTransactionRow = namedtuple(
4243 "_TransactionRow", ("response_code", "response_json")
4344 )
45
46
47 class DestinationSortOrder(Enum):
48 """Enum to define the sorting method used when returning destinations."""
49
50 DESTINATION = "destination"
51 RETRY_LAST_TS = "retry_last_ts"
 52 RETRY_INTERVAL = "retry_interval"
53 FAILURE_TS = "failure_ts"
54 LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
4455
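`get_destinations_paginate` below validates the caller-supplied `order_by` with `DestinationSortOrder(order_by).value`: looking an untrusted string up through the enum either round-trips a known column name or raises `ValueError`, so arbitrary strings never reach the SQL `ORDER BY`. A minimal demonstration with a trimmed member list:

```python
from enum import Enum

class DestinationSortOrder(Enum):
    DESTINATION = "destination"
    RETRY_LAST_TS = "retry_last_ts"

def validate_sort_order(order_by: str) -> str:
    # Enum lookup by value: known values round-trip, anything else raises
    # ValueError before it can be interpolated into an ORDER BY clause.
    return DestinationSortOrder(order_by).value
```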
4556
4657 @attr.s(slots=True, frozen=True, auto_attribs=True)
479490
480491 destinations = [row[0] for row in txn]
481492 return destinations
493
494 async def get_destinations_paginate(
495 self,
496 start: int,
497 limit: int,
498 destination: Optional[str] = None,
499 order_by: str = DestinationSortOrder.DESTINATION.value,
500 direction: str = "f",
501 ) -> Tuple[List[JsonDict], int]:
502 """Function to retrieve a paginated list of destinations.
503 This will return a json list of destinations and the
504 total number of destinations matching the filter criteria.
505
506 Args:
507 start: start number to begin the query from
508 limit: number of rows to retrieve
509 destination: search string in destination
510 order_by: the sort order of the returned list
511 direction: sort ascending or descending
512 Returns:
513 A tuple of a list of mappings from destination to information
514 and a count of total destinations.
515 """
516
517 def get_destinations_paginate_txn(
518 txn: LoggingTransaction,
519 ) -> Tuple[List[JsonDict], int]:
520 order_by_column = DestinationSortOrder(order_by).value
521
522 if direction == "b":
523 order = "DESC"
524 else:
525 order = "ASC"
526
527 args = []
528 where_statement = ""
529 if destination:
530 args.extend(["%" + destination.lower() + "%"])
531 where_statement = "WHERE LOWER(destination) LIKE ?"
532
533 sql_base = f"FROM destinations {where_statement} "
534 sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
535 txn.execute(sql, args)
536 count = txn.fetchone()[0]
537
538 sql = f"""
539 SELECT destination, retry_last_ts, retry_interval, failure_ts,
540 last_successful_stream_ordering
541 {sql_base}
542 ORDER BY {order_by_column} {order}, destination ASC
543 LIMIT ? OFFSET ?
544 """
545 txn.execute(sql, args + [limit, start])
546 destinations = self.db_pool.cursor_to_dict(txn)
547 return destinations, count
548
549 return await self.db_pool.runInteraction(
550 "get_destinations_paginate_txn", get_destinations_paginate_txn
551 )
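The count-then-page pattern above maps directly onto `LIMIT ?`/`OFFSET ?`. A runnable sketch against an in-memory table with a subset of the same columns:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE destinations (destination TEXT, retry_last_ts INTEGER)")
conn.executemany(
    "INSERT INTO destinations VALUES (?, ?)",
    [("a.example", 3), ("b.example", 1), ("c.example", 2)],
)

start, limit = 1, 2  # skip `start` rows, then return at most `limit`
total = conn.execute("SELECT COUNT(*) FROM destinations").fetchone()[0]
page = conn.execute(
    "SELECT destination FROM destinations"
    " ORDER BY retry_last_ts ASC, destination ASC LIMIT ? OFFSET ?",
    (limit, start),
).fetchall()
```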
582582 current_state_for_room=current_state_for_room,
583583 state_delta_for_room=state_delta_for_room,
584584 new_forward_extremeties=new_forward_extremeties,
585 backfilled=backfilled,
585 use_negative_stream_ordering=backfilled,
586 inhibit_local_membership_updates=backfilled,
586587 )
587588
588589 await self._handle_potentially_left_users(potentially_left_users)
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313
14 SCHEMA_VERSION = 65 # remember to update the list below when updating
14 SCHEMA_VERSION = 66 # remember to update the list below when updating
1515 """Represents the expectations made by the codebase about the database schema
1616
1717 This should be incremented whenever the codebase changes its requirements on the
4545 - MSC2716: Remove unique event_id constraint from insertion_event_edges
4646 because an insertion event can have multiple edges.
4747 - Remove unused tables `user_stats_historical` and `room_stats_historical`.
48
49 Changes in SCHEMA_VERSION = 66:
50 - Queries on state_key columns are now disambiguated (ie, the codebase can handle
51 the `events` table having a `state_key` column).
4852 """
4953
5054
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15
16 ALTER TABLE refresh_tokens
17 -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
18 -- They may not be used after they have expired.
19 -- If null, then the refresh token's lifetime is unlimited.
20 ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
21
22 ALTER TABLE refresh_tokens
23 -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
24 -- No matter how much the access and refresh tokens are refreshed, they cannot
25 -- be extended past this time.
26 -- If null, then the session length is unlimited.
27 ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- Track the auth provider used by each login as well as the session ID
16 CREATE TABLE device_auth_providers (
17 user_id TEXT NOT NULL,
18 device_id TEXT NOT NULL,
19 auth_provider_id TEXT NOT NULL,
20 auth_provider_session_id TEXT NOT NULL
21 );
22
23 CREATE INDEX device_auth_providers_devices
24 ON device_auth_providers (user_id, device_id);
25 CREATE INDEX device_auth_providers_sessions
26 ON device_auth_providers (auth_provider_id, auth_provider_session_id);
8888 return (max if step > 0 else min)(current_id, step)
8989
9090
91 class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
91 class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
92 """Tracks the "current" stream ID of a stream that may have multiple writers.
93
94 Stream IDs are monotonically increasing or decreasing integers representing write
95 transactions. The "current" stream ID is the stream ID such that all transactions
96 with equal or smaller stream IDs have completed. Since transactions may complete out
97 of order, this is not the same as the stream ID of the last completed transaction.
98
99 Completed transactions include both committed transactions and transactions that
100 have been rolled back.
101 """
102
103 @abc.abstractmethod
104 def advance(self, instance_name: str, new_id: int) -> None:
105 """Advance the position of the named writer to the given ID, if greater
106 than existing entry.
107 """
108 raise NotImplementedError()
109
110 @abc.abstractmethod
111 def get_current_token(self) -> int:
112 """Returns the maximum stream id such that all stream ids less than or
113 equal to it have been successfully persisted.
114
115 Returns:
116 The maximum stream id.
117 """
118 raise NotImplementedError()
119
120 @abc.abstractmethod
121 def get_current_token_for_writer(self, instance_name: str) -> int:
122 """Returns the position of the given writer.
123
124 For streams with single writers this is equivalent to `get_current_token`.
125 """
126 raise NotImplementedError()
127
128
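The "current" stream ID invariant described in the `AbstractStreamIdTracker` docstring can be sketched for the single-writer case: the token only advances once every smaller ID has completed, even when transactions finish out of order.

```python
class MiniStreamIdTracker:
    """Single-writer sketch of the "current token" invariant."""

    def __init__(self) -> None:
        self._next = 0
        self._unfinished: set = set()

    def begin(self) -> int:
        # Allocate the next stream ID for a new write transaction.
        self._next += 1
        self._unfinished.add(self._next)
        return self._next

    def finish(self, stream_id: int) -> None:
        # Mark a transaction complete (committed *or* rolled back).
        self._unfinished.discard(stream_id)

    def current_token(self) -> int:
        # Largest ID such that every ID <= it has completed.  While an
        # older transaction is in flight, newer completions don't count.
        if self._unfinished:
            return min(self._unfinished) - 1
        return self._next
```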
129 class AbstractStreamIdGenerator(AbstractStreamIdTracker):
130 """Generates stream IDs for a stream that may have multiple writers.
131
132 Each stream ID represents a write transaction, whose completion is tracked
133 so that the "current" stream ID of the stream can be determined.
134
135 See `AbstractStreamIdTracker` for more details.
136 """
137
92138 @abc.abstractmethod
93139 def get_next(self) -> AsyncContextManager[int]:
140 """
141 Usage:
142 async with stream_id_gen.get_next() as stream_id:
143 # ... persist event ...
144 """
94145 raise NotImplementedError()
95146
96147 @abc.abstractmethod
97148 def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
149 """
150 Usage:
151 async with stream_id_gen.get_next(n) as stream_ids:
152 # ... persist events ...
153 """
98154 raise NotImplementedError()
99155
100 @abc.abstractmethod
101 def get_current_token(self) -> int:
102 raise NotImplementedError()
103
104 @abc.abstractmethod
105 def get_current_token_for_writer(self, instance_name: str) -> int:
106 raise NotImplementedError()
107
108156
109157 class StreamIdGenerator(AbstractStreamIdGenerator):
110 """Used to generate new stream ids when persisting events while keeping
111 track of which transactions have been completed.
112
113 This allows us to get the "current" stream id, i.e. the stream id such that
114 all ids less than or equal to it have completed. This handles the fact that
115 persistence of events can complete out of order.
158 """Generates and tracks stream IDs for a stream with a single writer.
159
160 This class must only be used when the current Synapse process is the sole
161 writer for a stream.
116162
117163 Args:
118164 db_conn(connection): A database connection to use to fetch the
156202 # The key and values are the same, but we never look at the values.
157203 self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
158204
205 def advance(self, instance_name: str, new_id: int) -> None:
206 # `StreamIdGenerator` should only be used when there is a single writer,
207 # so replication should never happen.
208 raise Exception("Replication is not supported by StreamIdGenerator")
209
159210 def get_next(self) -> AsyncContextManager[int]:
160 """
161 Usage:
162 async with stream_id_gen.get_next() as stream_id:
163 # ... persist event ...
164 """
165211 with self._lock:
166212 self._current += self._step
167213 next_id = self._current
179225 return _AsyncCtxManagerWrapper(manager())
180226
181227 def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
182 """
183 Usage:
184 async with stream_id_gen.get_next(n) as stream_ids:
185 # ... persist events ...
186 """
187228 with self._lock:
188229 next_ids = range(
189230 self._current + self._step,
207248 return _AsyncCtxManagerWrapper(manager())
208249
209250 def get_current_token(self) -> int:
210 """Returns the maximum stream id such that all stream ids less than or
211 equal to it have been successfully persisted.
212
213 Returns:
214 The maximum stream id.
215 """
216251 with self._lock:
217252 if self._unfinished_ids:
218253 return next(iter(self._unfinished_ids)) - self._step
220255 return self._current
221256
222257 def get_current_token_for_writer(self, instance_name: str) -> int:
223 """Returns the position of the given writer.
224
225 For streams with single writers this is equivalent to
226 `get_current_token`.
227 """
228258 return self.get_current_token()
229259
230260
231261 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
232 """An ID generator that tracks a stream that can have multiple writers.
262 """Generates and tracks stream IDs for a stream with multiple writers.
233263
234264 Uses a Postgres sequence to coordinate ID assignment, but positions of other
235265 writers will only get updated when `advance` is called (by replication).
474504 return stream_ids
475505
476506 def get_next(self) -> AsyncContextManager[int]:
477 """
478 Usage:
479 async with stream_id_gen.get_next() as stream_id:
480 # ... persist event ...
481 """
482
483507 # If we have a list of instances that are allowed to write to this
484508 # stream, make sure we're in it.
485509 if self._writers and self._instance_name not in self._writers:
491515 return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
492516
493517 def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
494 """
495 Usage:
496 async with stream_id_gen.get_next_mult(5) as stream_ids:
497 # ... persist events ...
498 """
499
500518 # If we have a list of instances that are allowed to write to this
501519 # stream, make sure we're in it.
502520 if self._writers and self._instance_name not in self._writers:
596614 self._add_persisted_position(next_id)
597615
598616 def get_current_token(self) -> int:
599 """Returns the maximum stream id such that all stream ids less than or
600 equal to it have been successfully persisted.
601 """
602
603617 return self.get_persisted_upto_position()
604618
605619 def get_current_token_for_writer(self, instance_name: str) -> int:
606 """Returns the position of the given writer."""
607
608620 # If we don't have an entry for the given instance name, we assume it's a
609621 # new writer.
610622 #
630642 }
631643
632644 def advance(self, instance_name: str, new_id: int) -> None:
633 """Advance the position of the named writer to the given ID, if greater
634 than existing entry.
635 """
636
637645 new_id *= self._return_factor
638646
639647 with self._lock:
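The tracking scheme described in the docstrings above — the "current" token is the largest stream ID whose predecessors have all completed, even when persistence finishes out of order — can be sketched minimally. This is an illustrative simplification, not Synapse's actual `StreamIdGenerator`:

```python
import threading
from collections import OrderedDict
from contextlib import contextmanager


class SimpleStreamIdGen:
    """Simplified single-writer stream ID tracker: the "current"
    token is the largest ID such that every ID at or below it has
    finished, even when writes complete out of order."""

    def __init__(self, current: int = 0) -> None:
        self._lock = threading.Lock()
        self._current = current
        # IDs handed out but not yet completed, in issue order.
        self._unfinished: "OrderedDict[int, int]" = OrderedDict()

    @contextmanager
    def get_next(self):
        with self._lock:
            self._current += 1
            next_id = self._current
            self._unfinished[next_id] = next_id
        try:
            yield next_id
        finally:
            with self._lock:
                self._unfinished.pop(next_id)

    def get_current_token(self) -> int:
        with self._lock:
            if self._unfinished:
                # One less than the oldest unfinished ID is safe.
                return next(iter(self._unfinished)) - 1
            return self._current
```

Note that if ID 2 completes while ID 1 is still outstanding, the current token stays below 1 until ID 1 finishes — exactly the out-of-order behaviour the class docstrings describe.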
2121 Iterable,
2222 MutableMapping,
2323 Optional,
24 Sized,
2425 TypeVar,
2526 Union,
2627 cast,
103104 max_size=max_entries,
104105 cache_name=name,
105106 cache_type=cache_type,
106 size_callback=(lambda d: len(d) or 1) if iterable else None,
107 size_callback=(
108 (lambda d: len(cast(Sized, d)) or 1)
109 # Argument 1 to "len" has incompatible type "VT"; expected "Sized"
110 # We trust that `VT` is `Sized` when `iterable` is `True`
111 if iterable
112 else None
113 ),
107114 metrics_collection_callback=metrics_cb,
108115 apply_cache_factor_from_config=apply_cache_factor_from_config,
109116 prune_unread_entries=prune_unread_entries,
1414 import logging
1515 import threading
1616 import weakref
17 from enum import Enum
1718 from functools import wraps
1819 from typing import (
1920 TYPE_CHECKING,
2021 Any,
2122 Callable,
2223 Collection,
24 Dict,
2325 Generic,
24 Iterable,
2526 List,
2627 Optional,
2728 Type,
189190 root: "ListNode[_Node]",
190191 key: KT,
191192 value: VT,
192 cache: "weakref.ReferenceType[LruCache]",
193 cache: "weakref.ReferenceType[LruCache[KT, VT]]",
193194 clock: Clock,
194195 callbacks: Collection[Callable[[], None]] = (),
195196 prune_unread_entries: bool = True,
269270 removed from all lists.
270271 """
271272 cache = self._cache()
272 if not cache or not cache.pop(self.key, None):
273 if (
274 cache is None
275 or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
276 ):
273277 # `cache.pop` should call `drop_from_lists()`, unless this Node had
274278 # already been removed from the cache.
275279 self.drop_from_lists()
289293 self._global_list_node.update_last_access(clock)
290294
291295
296 class _Sentinel(Enum):
297 # defining a sentinel in this way allows mypy to correctly handle the
298 # type of a dictionary lookup.
299 sentinel = object()
300
301
292302 class LruCache(Generic[KT, VT]):
293303 """
294304 Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
301311 max_size: int,
302312 cache_name: Optional[str] = None,
303313 cache_type: Type[Union[dict, TreeCache]] = dict,
304 size_callback: Optional[Callable] = None,
314 size_callback: Optional[Callable[[VT], int]] = None,
305315 metrics_collection_callback: Optional[Callable[[], None]] = None,
306316 apply_cache_factor_from_config: bool = True,
307317 clock: Optional[Clock] = None,
338348 else:
339349 real_clock = clock
340350
341 cache = cache_type()
351 cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
342352 self.cache = cache # Used for introspection.
343353 self.apply_cache_factor_from_config = apply_cache_factor_from_config
344354
373383 # creating more each time we create a `_Node`.
374384 weak_ref_to_self = weakref.ref(self)
375385
376 list_root = ListNode[_Node].create_root_node()
386 list_root = ListNode[_Node[KT, VT]].create_root_node()
377387
378388 lock = threading.Lock()
379389
421431 def add_node(
422432 key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
423433 ) -> None:
424 node = _Node(
434 node: _Node[KT, VT] = _Node(
425435 list_root,
426436 key,
427437 value,
438448 if caches.TRACK_MEMORY_USAGE and metrics:
439449 metrics.inc_memory_usage(node.memory)
440450
441 def move_node_to_front(node: _Node) -> None:
451 def move_node_to_front(node: _Node[KT, VT]) -> None:
442452 node.move_to_front(real_clock, list_root)
443453
444 def delete_node(node: _Node) -> int:
454 def delete_node(node: _Node[KT, VT]) -> int:
445455 node.drop_from_lists()
446456
447457 deleted_len = 1
495505
496506 @synchronized
497507 def cache_set(
498 key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
508 key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
499509 ) -> None:
500510 node = cache.get(key, None)
501511 if node is not None:
589599 def cache_contains(key: KT) -> bool:
590600 return key in cache
591601
592 self.sentinel = object()
593
594602 # make sure that we clear out any excess entries after we get resized.
595603 self._on_resize = evict
596604
607615 self.clear = cache_clear
608616
609617 def __getitem__(self, key: KT) -> VT:
610 result = self.get(key, self.sentinel)
611 if result is self.sentinel:
618 result = self.get(key, _Sentinel.sentinel)
619 if result is _Sentinel.sentinel:
612620 raise KeyError()
613621 else:
614 return cast(VT, result)
622 return result
615623
616624 def __setitem__(self, key: KT, value: VT) -> None:
617625 self.set(key, value)
618626
619627 def __delitem__(self, key: KT, value: VT) -> None:
620 result = self.pop(key, self.sentinel)
621 if result is self.sentinel:
628 result = self.pop(key, _Sentinel.sentinel)
629 if result is _Sentinel.sentinel:
622630 raise KeyError()
623631
624632 def __len__(self) -> int:
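The `_Sentinel` enum introduced above replaces the old `self.sentinel = object()` attribute. Defining the sentinel as a one-member `Enum` gives it a distinct type that mypy can narrow on with an `is` check, removing the need for `cast`. A standalone sketch of the pattern (the `lookup` helper is illustrative, not part of Synapse):

```python
from enum import Enum
from typing import Dict, Union


class _Sentinel(Enum):
    # A single-member Enum gives the sentinel its own type, so mypy
    # can narrow Union[int, _Sentinel] to int after an `is` check.
    sentinel = object()


def lookup(d: Dict[str, int], key: str) -> int:
    result: Union[int, _Sentinel] = d.get(key, _Sentinel.sentinel)
    if result is _Sentinel.sentinel:
        raise KeyError(key)
    # mypy now knows `result` is an int; no cast() needed.
    return result
```

A plain `object()` sentinel forces the return type to `Union[VT, object]`, which is why the old `__getitem__` needed `cast(VT, result)`.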
8383 # immediately rather than at the next GC.
8484 self.cache_entry = None
8585
86 def move_after(self, node: "ListNode") -> None:
86 def move_after(self, node: "ListNode[P]") -> None:
8787 """Move this node from its current location in the list to after the
8888 given node.
8989 """
121121 self.prev_node = None
122122 self.next_node = None
123123
124 def _refs_insert_after(self, node: "ListNode") -> None:
124 def _refs_insert_after(self, node: "ListNode[P]") -> None:
125125 """Internal method to insert the node after the given node."""
126126
127127 # This method should only be called when we're not already in the list.
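The annotation changes above parameterise forward references to the generic class: a bare `"ListNode"` is implicitly `ListNode[Any]`, so the element type is lost at each hop through the list. A minimal sketch of the idea, heavily simplified from the real class:

```python
from typing import Generic, Optional, TypeVar

P = TypeVar("P")


class Node(Generic[P]):
    def __init__(self, value: P) -> None:
        self.value = value
        # Writing "Node[P]" (not bare "Node") preserves the type
        # parameter, so mypy knows neighbours hold the same P.
        self.next_node: Optional["Node[P]"] = None

    def insert_after(self, node: "Node[P]") -> None:
        node.next_node = self.next_node
        self.next_node = node
```

With the bare annotation, `a.next_node.value` would type-check as `Any`; with `"Node[P]"` it is correctly inferred as `P`.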
00 # Copyright 2016 OpenMarket Ltd
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
12 #
23 # Licensed under the Apache License, Version 2.0 (the "License");
34 # you may not use this file except in compliance with the License.
2829 If called on a module not in a git checkout will return `__version__`.
2930
3031 Args:
31 module (module)
32 module: The module to check the version of. Must declare a __version__
33 attribute.
3234
3335 Returns:
34 str
36 The module version (as a string).
3537 """
3638
3739 cached_version = version_cache.get(module)
4345 version_string = module.__version__ # type: ignore[attr-defined]
4446
4547 try:
46 null = open(os.devnull, "w")
4748 cwd = os.path.dirname(os.path.abspath(module.__file__))
4849
49 try:
50 git_branch = (
51 subprocess.check_output(
52 ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
50 def _run_git_command(prefix: str, *params: str) -> str:
51 try:
52 result = (
53 subprocess.check_output(
54 ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd
55 )
56 .strip()
57 .decode("ascii")
5358 )
54 .strip()
55 .decode("ascii")
56 )
57 git_branch = "b=" + git_branch
58 except (subprocess.CalledProcessError, FileNotFoundError):
59 # FileNotFoundError can arise when git is not installed
60 git_branch = ""
59 return prefix + result
60 except (subprocess.CalledProcessError, FileNotFoundError):
61 return ""
6162
62 try:
63 git_tag = (
64 subprocess.check_output(
65 ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
66 )
67 .strip()
68 .decode("ascii")
69 )
70 git_tag = "t=" + git_tag
71 except (subprocess.CalledProcessError, FileNotFoundError):
72 git_tag = ""
63 git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD")
64 git_tag = _run_git_command("t=", "describe", "--exact-match")
65 git_commit = _run_git_command("", "rev-parse", "--short", "HEAD")
7366
74 try:
75 git_commit = (
76 subprocess.check_output(
77 ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
78 )
79 .strip()
80 .decode("ascii")
81 )
82 except (subprocess.CalledProcessError, FileNotFoundError):
83 git_commit = ""
84
85 try:
86 dirty_string = "-this_is_a_dirty_checkout"
87 is_dirty = (
88 subprocess.check_output(
89 ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
90 )
91 .strip()
92 .decode("ascii")
93 .endswith(dirty_string)
94 )
95
96 git_dirty = "dirty" if is_dirty else ""
97 except (subprocess.CalledProcessError, FileNotFoundError):
98 git_dirty = ""
67 dirty_string = "-this_is_a_dirty_checkout"
68 is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith(
69 dirty_string
70 )
71 git_dirty = "dirty" if is_dirty else ""
9972
10073 if git_branch or git_tag or git_commit or git_dirty:
10174 git_version = ",".join(
10275 s for s in (git_branch, git_tag, git_commit, git_dirty) if s
10376 )
10477
105 version_string = "%s (%s)" % (
106 # If the __version__ attribute doesn't exist, we'll have failed
107 # loudly above.
108 module.__version__, # type: ignore[attr-defined]
109 git_version,
110 )
78 version_string = f"{version_string} ({git_version})"
11179 except Exception as e:
11280 logger.info("Failed to check for git repository: %s", e)
11381
2323 import subprocess
2424 import sys
2525 import time
26 from typing import Iterable
26 from typing import Iterable, Optional
2727
2828 import yaml
2929
4040 def pid_running(pid):
4141 try:
4242 os.kill(pid, 0)
43 return True
4443 except OSError as err:
4544 if err.errno == errno.EPERM:
46 return True
47 return False
45 pass # process exists
46 else:
47 return False
48
49 # When running in a container, orphan processes may not get reaped and their
50 # PIDs may remain valid. Try to work around the issue.
51 try:
52 with open(f"/proc/{pid}/status") as status_file:
53 if "zombie" in status_file.read():
54 return False
55 except Exception:
56 # This isn't Linux or `/proc/` is unavailable.
57 # Assume that the process is still running.
58 pass
59
60 return True
4861
4962
5063 def write(message, colour=NORMAL, stream=sys.stdout):
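The `pid_running` change above keeps the signal-0 probe and adds a `/proc` check for unreaped zombies in containers. The probe relies on `os.kill(pid, 0)` performing only error checking — no signal is delivered. A hedged, Unix-only sketch of that idiom on its own:

```python
import errno
import os


def pid_exists(pid: int) -> bool:
    """Probe a PID with signal 0: nothing is delivered, but the
    kernel still runs existence and permission checks."""
    try:
        os.kill(pid, 0)
    except OSError as err:
        # EPERM: the process exists but we may not signal it.
        # ESRCH (or anything else): no such process.
        return err.errno == errno.EPERM
    return True
```

This is why the patched `pid_running` treats `EPERM` as "process exists" and then consults `/proc/<pid>/status`, since a zombie still passes the signal-0 check.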
108121 return False
109122
110123
111 def stop(pidfile: str, app: str) -> bool:
124 def stop(pidfile: str, app: str) -> Optional[int]:
112125 """Attempts to kill a synapse worker from the pidfile.
113126 Args:
114127 pidfile: path to file containing worker's pid
115128 app: name of the worker's appservice
116129
117130 Returns:
118 True if the process stopped successfully
119 False if process was already stopped or an error occured
131 The process ID, or None if the process was not running.
120132 """
121133
122134 if os.path.exists(pidfile):
124136 try:
125137 os.kill(pid, signal.SIGTERM)
126138 write("stopped %s" % (app,), colour=GREEN)
127 return True
139 return pid
128140 except OSError as err:
129141 if err.errno == errno.ESRCH:
130142 write("%s not running" % (app,), colour=YELLOW)
132144 abort("Cannot stop %s: Operation not permitted" % (app,))
133145 else:
134146 abort("Cannot stop %s: Unknown error" % (app,))
135 return False
136147 else:
137148 write(
138149 "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)"
139150 % (app, pidfile),
140151 colour=YELLOW,
141152 )
142 return False
153 return None
143154
144155
145156 Worker = collections.namedtuple(
287298 action = options.action
288299
289300 if action == "stop" or action == "restart":
290 has_stopped = True
301 running_pids = []
291302 for worker in workers:
292 if not stop(worker.pidfile, worker.app):
293 # A worker could not be stopped.
294 has_stopped = False
303 pid = stop(worker.pidfile, worker.app)
304 if pid is not None:
305 running_pids.append(pid)
295306
296307 if start_stop_synapse:
297 if not stop(pidfile, MAIN_PROCESS):
298 has_stopped = False
299 if not has_stopped and action == "stop":
300 sys.exit(1)
301
302 # Wait for synapse to actually shutdown before starting it again
303 if action == "restart":
304 running_pids = []
305 if start_stop_synapse and os.path.exists(pidfile):
306 running_pids.append(int(open(pidfile).read()))
307 for worker in workers:
308 if os.path.exists(worker.pidfile):
309 running_pids.append(int(open(worker.pidfile).read()))
308 pid = stop(pidfile, MAIN_PROCESS)
309 if pid is not None:
310 running_pids.append(pid)
311
310312 if len(running_pids) > 0:
311 write("Waiting for process to exit before restarting...")
313 write("Waiting for processes to exit...")
312314 for running_pid in running_pids:
313315 while pid_running(running_pid):
314316 time.sleep(0.2)
315 write("All processes exited; now restarting...")
317 write("All processes exited")
316318
317319 if action == "start" or action == "restart":
318320 error = False
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 import synapse.app.homeserver
15 from synapse.config._base import ConfigError
16
17 from tests.config.utils import ConfigFileTestCase
18
19
20 class HomeserverAppStartTestCase(ConfigFileTestCase):
21 def test_wrong_start_caught(self):
22 # Generate a config with a worker_app
23 self.generate_config()
24 # Add a blank line as otherwise the next addition ends up on a line with a comment
25 self.add_lines_to_config([" "])
26 self.add_lines_to_config(["worker_app: test_worker_app"])
27
28 # Ensure that starting master process with worker config raises an exception
29 with self.assertRaises(ConfigError):
30 synapse.app.homeserver.setup(["-c", self.config_file])
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 from synapse.config import ConfigError
14 from synapse.config.homeserver import HomeServerConfig
15
16 from tests.unittest import TestCase
17 from tests.utils import default_config
18
19
20 class RegistrationConfigTestCase(TestCase):
21 def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
22 """
23 session_lifetime should logically be larger than, or at least as large as,
24 all the different token lifetimes.
25 Test that the user is faced with configuration errors if they make it
26 smaller, as that configuration doesn't make sense.
27 """
28 config_dict = default_config("test")
29
30 # First test all the error conditions
31 with self.assertRaises(ConfigError):
32 HomeServerConfig().parse_config_dict(
33 {
34 "session_lifetime": "30m",
35 "nonrefreshable_access_token_lifetime": "31m",
36 **config_dict,
37 }
38 )
39
40 with self.assertRaises(ConfigError):
41 HomeServerConfig().parse_config_dict(
42 {
43 "session_lifetime": "30m",
44 "refreshable_access_token_lifetime": "31m",
45 **config_dict,
46 }
47 )
48
49 with self.assertRaises(ConfigError):
50 HomeServerConfig().parse_config_dict(
51 {
52 "session_lifetime": "30m",
53 "refresh_token_lifetime": "31m",
54 **config_dict,
55 }
56 )
57
58 # Then test all the fine conditions
59 HomeServerConfig().parse_config_dict(
60 {
61 "session_lifetime": "31m",
62 "nonrefreshable_access_token_lifetime": "31m",
63 **config_dict,
64 }
65 )
66
67 HomeServerConfig().parse_config_dict(
68 {
69 "session_lifetime": "31m",
70 "refreshable_access_token_lifetime": "31m",
71 **config_dict,
72 }
73 )
74
75 HomeServerConfig().parse_config_dict(
76 {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
77 )
2121 from nacl.signing import SigningKey
2222 from signedjson.key import encode_verify_key_base64, get_verify_key
2323
24 from twisted.internet import defer
2425 from twisted.internet.defer import Deferred, ensureDeferred
2526
2627 from synapse.api.errors import SynapseError
576577 bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
577578 )
578579
580 def test_get_multiple_keys_from_perspectives(self):
581 """Check that we can correctly request multiple keys for the same server"""
582
583 fetcher = PerspectivesKeyFetcher(self.hs)
584
585 SERVER_NAME = "server2"
586
587 testkey1 = signedjson.key.generate_signing_key("ver1")
588 testverifykey1 = signedjson.key.get_verify_key(testkey1)
589 testverifykey1_id = "ed25519:ver1"
590
591 testkey2 = signedjson.key.generate_signing_key("ver2")
592 testverifykey2 = signedjson.key.get_verify_key(testkey2)
593 testverifykey2_id = "ed25519:ver2"
594
595 VALID_UNTIL_TS = 200 * 1000
596
597 response1 = self.build_perspectives_response(
598 SERVER_NAME,
599 testkey1,
600 VALID_UNTIL_TS,
601 )
602 response2 = self.build_perspectives_response(
603 SERVER_NAME,
604 testkey2,
605 VALID_UNTIL_TS,
606 )
607
608 async def post_json(destination, path, data, **kwargs):
609 self.assertEqual(destination, self.mock_perspective_server.server_name)
610 self.assertEqual(path, "/_matrix/key/v2/query")
611
612 # check that the request is for the expected keys
613 q = data["server_keys"]
614
615 self.assertEqual(
616 list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id]
617 )
618 return {"server_keys": [response1, response2]}
619
620 self.http_client.post_json.side_effect = post_json
621
622 # fire off two separate requests; they should get merged together into a
623 # single HTTP hit.
624 request1_d = defer.ensureDeferred(
625 fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0)
626 )
627 request2_d = defer.ensureDeferred(
628 fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0)
629 )
630
631 keys1 = self.get_success(request1_d)
632 self.assertIn(testverifykey1_id, keys1)
633 k = keys1[testverifykey1_id]
634 self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
635 self.assertEqual(k.verify_key, testverifykey1)
636 self.assertEqual(k.verify_key.alg, "ed25519")
637 self.assertEqual(k.verify_key.version, "ver1")
638
639 keys2 = self.get_success(request2_d)
640 self.assertIn(testverifykey2_id, keys2)
641 k = keys2[testverifykey2_id]
642 self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
643 self.assertEqual(k.verify_key, testverifykey2)
644 self.assertEqual(k.verify_key.alg, "ed25519")
645 self.assertEqual(k.verify_key.version, "ver2")
646
647 # finally, ensure that only one request was sent
648 self.assertEqual(self.http_client.post_json.call_count, 1)
649
579650 def test_get_perspectives_own_key(self):
580651 """Check that we can get the perspectives server's own keys
581652
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 import json
15
16 from synapse.api.room_versions import RoomVersions
17 from synapse.federation.transport.client import SendJoinParser
18
19 from tests.unittest import TestCase
20
21
22 class SendJoinParserTestCase(TestCase):
23 def test_two_writes(self) -> None:
24 """Test that the parser can sensibly deserialise an input given in two slices."""
25 parser = SendJoinParser(RoomVersions.V1, True)
26 parent_event = {
27 "content": {
28 "see_room_version_spec": "The event format changes depending on the room version."
29 },
30 "event_id": "$authparent",
31 "room_id": "!somewhere:example.org",
32 "type": "m.room.minimal_pdu",
33 }
34 state = {
35 "content": {
36 "see_room_version_spec": "The event format changes depending on the room version."
37 },
38 "event_id": "$DoNotThinkAboutTheEvent",
39 "room_id": "!somewhere:example.org",
40 "type": "m.room.minimal_pdu",
41 }
42 response = [
43 200,
44 {
45 "auth_chain": [parent_event],
46 "origin": "matrix.org",
47 "state": [state],
48 },
49 ]
50 serialised_response = json.dumps(response).encode()
51
52 # Send data to the parser
53 parser.write(serialised_response[:100])
54 parser.write(serialised_response[100:])
55
56 # Retrieve the parsed SendJoinResponse
57 parsed_response = parser.finish()
58
59 # Sanity check the parsing gave us sensible data.
60 self.assertEqual(len(parsed_response.auth_events), 1, parsed_response)
61 self.assertEqual(len(parsed_response.state), 1, parsed_response)
62 self.assertEqual(parsed_response.event_dict, {}, parsed_response)
63 self.assertIsNone(parsed_response.event, parsed_response)
7070
7171 def test_short_term_login_token_gives_user_id(self):
7272 token = self.macaroon_generator.generate_short_term_login_token(
73 self.user1, "", 5000
73 self.user1, "", duration_in_ms=5000
7474 )
7575 res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
7676 self.assertEqual(self.user1, res.user_id)
9393
9494 def test_short_term_login_token_cannot_replace_user_id(self):
9595 token = self.macaroon_generator.generate_short_term_login_token(
96 self.user1, "", 5000
96 self.user1, "", duration_in_ms=5000
9797 )
9898 macaroon = pymacaroons.Macaroon.deserialize(token)
9999
212212
213213 def _get_macaroon(self):
214214 token = self.macaroon_generator.generate_short_term_login_token(
215 self.user1, "", 5000
215 self.user1, "", duration_in_ms=5000
216216 )
217217 return pymacaroons.Macaroon.deserialize(token)
6565
6666 # check that the auth handler got called as expected
6767 auth_handler.complete_sso_login.assert_called_once_with(
68 "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
68 "@test_user:test",
69 "cas",
70 request,
71 "redirect_uri",
72 None,
73 new_user=True,
74 auth_provider_session_id=None,
6975 )
7076
7177 def test_map_cas_user_to_existing_user(self):
8894
8995 # check that the auth handler got called as expected
9096 auth_handler.complete_sso_login.assert_called_once_with(
91 "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
97 "@test_user:test",
98 "cas",
99 request,
100 "redirect_uri",
101 None,
102 new_user=False,
103 auth_provider_session_id=None,
92104 )
93105
94106 # Subsequent calls should map to the same mxid.
97109 self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
98110 )
99111 auth_handler.complete_sso_login.assert_called_once_with(
100 "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
112 "@test_user:test",
113 "cas",
114 request,
115 "redirect_uri",
116 None,
117 new_user=False,
118 auth_provider_session_id=None,
101119 )
102120
103121 def test_map_cas_user_to_invalid_localpart(self):
115133
116134 # check that the auth handler got called as expected
117135 auth_handler.complete_sso_login.assert_called_once_with(
118 "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
136 "@f=c3=b6=c3=b6:test",
137 "cas",
138 request,
139 "redirect_uri",
140 None,
141 new_user=True,
142 auth_provider_session_id=None,
119143 )
120144
121145 @override_config(
159183
160184 # check that the auth handler got called as expected
161185 auth_handler.complete_sso_login.assert_called_once_with(
162 "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
186 "@test_user:test",
187 "cas",
188 request,
189 "redirect_uri",
190 None,
191 new_user=True,
192 auth_provider_session_id=None,
163193 )
164194
165195
251251 with patch.object(self.provider, "load_metadata", patched_load_metadata):
252252 self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
253253
254 # Return empty key set if JWKS are not used
255 self.provider._scopes = [] # not asking the openid scope
256 self.http_client.get_json.reset_mock()
257 jwks = self.get_success(self.provider.load_jwks(force=True))
258 self.http_client.get_json.assert_not_called()
259 self.assertEqual(jwks, {"keys": []})
260
261254 @override_config({"oidc_config": DEFAULT_CONFIG})
262255 def test_validate_config(self):
263256 """Provider metadatas are extensively validated."""
454447 self.get_success(self.handler.handle_oidc_callback(request))
455448
456449 auth_handler.complete_sso_login.assert_called_once_with(
457 expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
450 expected_user_id,
451 "oidc",
452 request,
453 client_redirect_url,
454 None,
455 new_user=True,
456 auth_provider_session_id=None,
458457 )
459458 self.provider._exchange_code.assert_called_once_with(code)
460459 self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
481480 self.provider._fetch_userinfo.reset_mock()
482481
483482 # With userinfo fetching
484 self.provider._scopes = [] # do not ask the "openid" scope
483 self.provider._user_profile_method = "userinfo_endpoint"
484 token = {
485 "type": "bearer",
486 "access_token": "access_token",
487 }
488 self.provider._exchange_code = simple_async_mock(return_value=token)
485489 self.get_success(self.handler.handle_oidc_callback(request))
486490
487491 auth_handler.complete_sso_login.assert_called_once_with(
488 expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
492 expected_user_id,
493 "oidc",
494 request,
495 client_redirect_url,
496 None,
497 new_user=False,
498 auth_provider_session_id=None,
489499 )
490500 self.provider._exchange_code.assert_called_once_with(code)
491501 self.provider._parse_id_token.assert_not_called()
502 self.provider._fetch_userinfo.assert_called_once_with(token)
503 self.render_error.assert_not_called()
504
505 # With an ID token, userinfo fetching and sid in the ID token
506 self.provider._user_profile_method = "userinfo_endpoint"
507 token = {
508 "type": "bearer",
509 "access_token": "access_token",
510 "id_token": "id_token",
511 }
512 id_token = {
513 "sid": "abcdefgh",
514 }
515 self.provider._parse_id_token = simple_async_mock(return_value=id_token)
516 self.provider._exchange_code = simple_async_mock(return_value=token)
517 auth_handler.complete_sso_login.reset_mock()
518 self.provider._fetch_userinfo.reset_mock()
519 self.get_success(self.handler.handle_oidc_callback(request))
520
521 auth_handler.complete_sso_login.assert_called_once_with(
522 expected_user_id,
523 "oidc",
524 request,
525 client_redirect_url,
526 None,
527 new_user=False,
528 auth_provider_session_id=id_token["sid"],
529 )
530 self.provider._exchange_code.assert_called_once_with(code)
531 self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
492532 self.provider._fetch_userinfo.assert_called_once_with(token)
493533 self.render_error.assert_not_called()
494534
775815 client_redirect_url,
776816 {"phone": "1234567"},
777817 new_user=True,
818 auth_provider_session_id=None,
778819 )
779820
780821 @override_config({"oidc_config": DEFAULT_CONFIG})
789830 }
790831 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
791832 auth_handler.complete_sso_login.assert_called_once_with(
792 "@test_user:test", "oidc", ANY, ANY, None, new_user=True
833 "@test_user:test",
834 "oidc",
835 ANY,
836 ANY,
837 None,
838 new_user=True,
839 auth_provider_session_id=None,
793840 )
794841 auth_handler.complete_sso_login.reset_mock()
795842
800847 }
801848 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
802849 auth_handler.complete_sso_login.assert_called_once_with(
803 "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
850 "@test_user_2:test",
851 "oidc",
852 ANY,
853 ANY,
854 None,
855 new_user=True,
856 auth_provider_session_id=None,
804857 )
805858 auth_handler.complete_sso_login.reset_mock()
806859
837890 }
838891 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
839892 auth_handler.complete_sso_login.assert_called_once_with(
840 user.to_string(), "oidc", ANY, ANY, None, new_user=False
893 user.to_string(),
894 "oidc",
895 ANY,
896 ANY,
897 None,
898 new_user=False,
899 auth_provider_session_id=None,
841900 )
842901 auth_handler.complete_sso_login.reset_mock()
843902
844903 # Subsequent calls should map to the same mxid.
845904 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
846905 auth_handler.complete_sso_login.assert_called_once_with(
847 user.to_string(), "oidc", ANY, ANY, None, new_user=False
906 user.to_string(),
907 "oidc",
908 ANY,
909 ANY,
910 None,
911 new_user=False,
912 auth_provider_session_id=None,
848913 )
849914 auth_handler.complete_sso_login.reset_mock()
850915
859924 }
860925 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
861926 auth_handler.complete_sso_login.assert_called_once_with(
862 user.to_string(), "oidc", ANY, ANY, None, new_user=False
927 user.to_string(),
928 "oidc",
929 ANY,
930 ANY,
931 None,
932 new_user=False,
933 auth_provider_session_id=None,
863934 )
864935 auth_handler.complete_sso_login.reset_mock()
865936
895966
896967 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
897968 auth_handler.complete_sso_login.assert_called_once_with(
898 "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
969 "@TEST_USER_2:test",
970 "oidc",
971 ANY,
972 ANY,
973 None,
974 new_user=False,
975 auth_provider_session_id=None,
899976 )
900977
901978 @override_config({"oidc_config": DEFAULT_CONFIG})
9331010
9341011 # test_user is already taken, so test_user1 gets registered instead.
9351012 auth_handler.complete_sso_login.assert_called_once_with(
936 "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
1013 "@test_user1:test",
1014 "oidc",
1015 ANY,
1016 ANY,
1017 None,
1018 new_user=True,
1019 auth_provider_session_id=None,
9371020 )
9381021 auth_handler.complete_sso_login.reset_mock()
9391022
10171100
10181101 # check that the auth handler got called as expected
10191102 auth_handler.complete_sso_login.assert_called_once_with(
1020 "@tester:test", "oidc", ANY, ANY, None, new_user=True
1103 "@tester:test",
1104 "oidc",
1105 ANY,
1106 ANY,
1107 None,
1108 new_user=True,
1109 auth_provider_session_id=None,
10211110 )
10221111
10231112 @override_config(
10421131
10431132 # check that the auth handler got called as expected
10441133 auth_handler.complete_sso_login.assert_called_once_with(
1045 "@tester:test", "oidc", ANY, ANY, None, new_user=True
1134 "@tester:test",
1135 "oidc",
1136 ANY,
1137 ANY,
1138 None,
1139 new_user=True,
1140 auth_provider_session_id=None,
10461141 )
10471142
10481143 @override_config(
11551250
11561251 handler = hs.get_oidc_handler()
11571252 provider = handler._providers["oidc"]
1158 provider._exchange_code = simple_async_mock(return_value={})
1253 provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
11591254 provider._parse_id_token = simple_async_mock(return_value=userinfo)
11601255 provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
11611256
3131 from synapse.rest import admin
3232 from synapse.rest.client import login, room
3333 from synapse.server import HomeServer
34 from synapse.types import JsonDict, UserID
34 from synapse.types import JsonDict, UserID, create_requester
3535
3636 from tests import unittest
3737
248248 self._assert_rooms(result, expected)
249249
250250 result = self.get_success(
251 self.handler.get_room_hierarchy(self.user, self.space)
251 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
252252 )
253253 self._assert_hierarchy(result, expected)
254254
262262 expected = [(self.space, [self.room]), (self.room, ())]
263263 self._assert_rooms(result, expected)
264264
265 result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
265 result = self.get_success(
266 self.handler.get_room_hierarchy(create_requester(user2), self.space)
267 )
266268 self._assert_hierarchy(result, expected)
267269
268270 # If the space is made invite-only, it should no longer be viewable.
273275 tok=self.token,
274276 )
275277 self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
276 self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
278 self.get_failure(
279 self.handler.get_room_hierarchy(create_requester(user2), self.space),
280 AuthError,
281 )
277282
278283 # If the space is made world-readable it should return a result.
279284 self.helper.send_state(
285290 result = self.get_success(self.handler.get_space_summary(user2, self.space))
286291 self._assert_rooms(result, expected)
287292
288 result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
293 result = self.get_success(
294 self.handler.get_room_hierarchy(create_requester(user2), self.space)
295 )
289296 self._assert_hierarchy(result, expected)
290297
291298 # Make it not world-readable again and confirm it results in an error.
296303 tok=self.token,
297304 )
298305 self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
299 self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
306 self.get_failure(
307 self.handler.get_room_hierarchy(create_requester(user2), self.space),
308 AuthError,
309 )
300310
301311 # Join the space and results should be returned.
302312 self.helper.invite(self.space, targ=user2, tok=self.token)
304314 result = self.get_success(self.handler.get_space_summary(user2, self.space))
305315 self._assert_rooms(result, expected)
306316
307 result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
317 result = self.get_success(
318 self.handler.get_room_hierarchy(create_requester(user2), self.space)
319 )
308320 self._assert_hierarchy(result, expected)
309321
310322 # Attempting to view an unknown room returns the same error.
313325 AuthError,
314326 )
315327 self.get_failure(
316 self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname),
328 self.handler.get_room_hierarchy(
329 create_requester(user2), "#not-a-space:" + self.hs.hostname
330 ),
317331 AuthError,
318332 )
319333
321335 """In-flight room hierarchy requests are deduplicated."""
322336 # Run two `get_room_hierarchy` calls up until they block.
323337 deferred1 = ensureDeferred(
324 self.handler.get_room_hierarchy(self.user, self.space)
338 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
325339 )
326340 deferred2 = ensureDeferred(
327 self.handler.get_room_hierarchy(self.user, self.space)
341 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
328342 )
329343
330344 # Complete the two calls.
339353
340354 # A subsequent `get_room_hierarchy` call should not reuse the result.
341355 result3 = self.get_success(
342 self.handler.get_room_hierarchy(self.user, self.space)
356 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
343357 )
344358 self._assert_hierarchy(result3, expected)
345359 self.assertIsNot(result1, result3)
358372
359373 # Run two `get_room_hierarchy` calls for different users up until they block.
360374 deferred1 = ensureDeferred(
361 self.handler.get_room_hierarchy(self.user, self.space)
362 )
363 deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
375 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
376 )
377 deferred2 = ensureDeferred(
378 self.handler.get_room_hierarchy(create_requester(user2), self.space)
379 )
364380
365381 # Complete the two calls.
366382 result1 = self.get_success(deferred1)
464480 ]
465481 self._assert_rooms(result, expected)
466482
467 result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
483 result = self.get_success(
484 self.handler.get_room_hierarchy(create_requester(user2), self.space)
485 )
468486 self._assert_hierarchy(result, expected)
469487
470488 def test_complex_space(self):
506524 self._assert_rooms(result, expected)
507525
508526 result = self.get_success(
509 self.handler.get_room_hierarchy(self.user, self.space)
527 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
510528 )
511529 self._assert_hierarchy(result, expected)
512530
521539 room_ids.append(self.room)
522540
523541 result = self.get_success(
524 self.handler.get_room_hierarchy(self.user, self.space, limit=7)
542 self.handler.get_room_hierarchy(
543 create_requester(self.user), self.space, limit=7
544 )
525545 )
526546 # The result should have the space and all of the links, plus some of the
527547 # rooms and a pagination token.
533553 # Check the next page.
534554 result = self.get_success(
535555 self.handler.get_room_hierarchy(
536 self.user, self.space, limit=5, from_token=result["next_batch"]
556 create_requester(self.user),
557 self.space,
558 limit=5,
559 from_token=result["next_batch"],
537560 )
538561 )
539562 # The result should have the space and the room in it, along with a link
553576 room_ids.append(self.room)
554577
555578 result = self.get_success(
556 self.handler.get_room_hierarchy(self.user, self.space, limit=7)
579 self.handler.get_room_hierarchy(
580 create_requester(self.user), self.space, limit=7
581 )
557582 )
558583 self.assertIn("next_batch", result)
559584
560585 # Changing the room ID, suggested-only, or max-depth causes an error.
561586 self.get_failure(
562587 self.handler.get_room_hierarchy(
563 self.user, self.room, from_token=result["next_batch"]
588 create_requester(self.user), self.room, from_token=result["next_batch"]
564589 ),
565590 SynapseError,
566591 )
567592 self.get_failure(
568593 self.handler.get_room_hierarchy(
569 self.user,
594 create_requester(self.user),
570595 self.space,
571596 suggested_only=True,
572597 from_token=result["next_batch"],
575600 )
576601 self.get_failure(
577602 self.handler.get_room_hierarchy(
578 self.user, self.space, max_depth=0, from_token=result["next_batch"]
603 create_requester(self.user),
604 self.space,
605 max_depth=0,
606 from_token=result["next_batch"],
579607 ),
580608 SynapseError,
581609 )
582610
583611 # An invalid token is ignored.
584612 self.get_failure(
585 self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"),
613 self.handler.get_room_hierarchy(
614 create_requester(self.user), self.space, from_token="foo"
615 ),
586616 SynapseError,
587617 )
588618
608638
609639 # Test just the space itself.
610640 result = self.get_success(
611 self.handler.get_room_hierarchy(self.user, self.space, max_depth=0)
641 self.handler.get_room_hierarchy(
642 create_requester(self.user), self.space, max_depth=0
643 )
612644 )
613645 expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])]
614646 self._assert_hierarchy(result, expected)
615647
616648 # A single additional layer.
617649 result = self.get_success(
618 self.handler.get_room_hierarchy(self.user, self.space, max_depth=1)
650 self.handler.get_room_hierarchy(
651 create_requester(self.user), self.space, max_depth=1
652 )
619653 )
620654 expected += [
621655 (rooms[0], ()),
625659
626660 # A few layers.
627661 result = self.get_success(
628 self.handler.get_room_hierarchy(self.user, self.space, max_depth=3)
662 self.handler.get_room_hierarchy(
663 create_requester(self.user), self.space, max_depth=3
664 )
629665 )
630666 expected += [
631667 (rooms[1], ()),
656692 self._assert_rooms(result, expected)
657693
658694 result = self.get_success(
659 self.handler.get_room_hierarchy(self.user, self.space)
695 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
660696 )
661697 self._assert_hierarchy(result, expected)
662698
738774 new=summarize_remote_room_hierarchy,
739775 ):
740776 result = self.get_success(
741 self.handler.get_room_hierarchy(self.user, self.space)
777 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
742778 )
743779 self._assert_hierarchy(result, expected)
744780
905941 new=summarize_remote_room_hierarchy,
906942 ):
907943 result = self.get_success(
908 self.handler.get_room_hierarchy(self.user, self.space)
944 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
909945 )
910946 self._assert_hierarchy(result, expected)
911947
963999 new=summarize_remote_room_hierarchy,
9641000 ):
9651001 result = self.get_success(
966 self.handler.get_room_hierarchy(self.user, self.space)
1002 self.handler.get_room_hierarchy(create_requester(self.user), self.space)
9671003 )
9681004 self._assert_hierarchy(result, expected)
9691005
129129
130130 # check that the auth handler got called as expected
131131 auth_handler.complete_sso_login.assert_called_once_with(
132 "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
132 "@test_user:test",
133 "saml",
134 request,
135 "redirect_uri",
136 None,
137 new_user=True,
138 auth_provider_session_id=None,
133139 )
134140
135141 @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
155161
156162 # check that the auth handler got called as expected
157163 auth_handler.complete_sso_login.assert_called_once_with(
158 "@test_user:test", "saml", request, "", None, new_user=False
164 "@test_user:test",
165 "saml",
166 request,
167 "",
168 None,
169 new_user=False,
170 auth_provider_session_id=None,
159171 )
160172
161173 # Subsequent calls should map to the same mxid.
164176 self.handler._handle_authn_response(request, saml_response, "")
165177 )
166178 auth_handler.complete_sso_login.assert_called_once_with(
167 "@test_user:test", "saml", request, "", None, new_user=False
179 "@test_user:test",
180 "saml",
181 request,
182 "",
183 None,
184 new_user=False,
185 auth_provider_session_id=None,
168186 )
169187
170188 def test_map_saml_response_to_invalid_localpart(self):
212230
213231 # test_user is already taken, so test_user1 gets registered instead.
214232 auth_handler.complete_sso_login.assert_called_once_with(
215 "@test_user1:test", "saml", request, "", None, new_user=True
233 "@test_user1:test",
234 "saml",
235 request,
236 "",
237 None,
238 new_user=True,
239 auth_provider_session_id=None,
216240 )
217241 auth_handler.complete_sso_login.reset_mock()
218242
308332
309333 # check that the auth handler got called as expected
310334 auth_handler.complete_sso_login.assert_called_once_with(
311 "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
335 "@test_user:test",
336 "saml",
337 request,
338 "redirect_uri",
339 None,
340 new_user=True,
341 auth_provider_session_id=None,
312342 )
313343
314344
127127 )
128128
129129 self.auth_handler = hs.get_auth_handler()
130 self.store = hs.get_datastore()
130131
131132 def test_need_validated_email(self):
132133 """Test that we can only add an email pusher if the user has validated
407408 self.hs.get_datastore().db_pool.updates._all_done = False
408409
409410 # Now let's actually drive the updates to completion
410 while not self.get_success(
411 self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
412 ):
413 self.get_success(
414 self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
415 by=0.1,
416 )
411 self.wait_for_background_updates()
417412
418413 # Check that all pushers with unlinked addresses were deleted
419414 pushers = self.get_success(
1616 from synapse.api.room_versions import RoomVersion
1717 from synapse.rest import admin
1818 from synapse.rest.client import login, room, sync
19 from synapse.storage.util.id_generators import MultiWriterIdGenerator
1920
2021 from tests.replication._base import BaseMultiWorkerStreamTestCase
2122 from tests.server import make_request
192193 #
193194 # Worker2's event stream position will not advance until we call
194195 # __aexit__ again.
195 actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
196 worker_store2 = worker_hs2.get_datastore()
197 assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
198
199 actx = worker_store2._stream_id_gen.get_next()
196200 self.get_success(actx.__aenter__())
197201
198202 response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313
14 import json
1514 import os
1615 import urllib.parse
16 from http import HTTPStatus
1717 from unittest.mock import Mock
1818
1919 from twisted.internet.defer import Deferred
4040 def test_version_string(self):
4141 channel = self.make_request("GET", self.url, shorthand=False)
4242
43 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
43 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
4444 self.assertEqual(
4545 {"server_version", "python_version"}, set(channel.json_body.keys())
4646 )
6969 content={"localpart": "test"},
7070 )
7171
72 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
72 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
7373
7474 group_id = channel.json_body["group_id"]
7575
76 self._check_group(group_id, expect_code=200)
76 self._check_group(group_id, expect_code=HTTPStatus.OK)
7777
7878 # Invite/join another user
7979
8181 channel = self.make_request(
8282 "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
8383 )
84 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
84 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
8585
8686 url = "/groups/%s/self/accept_invite" % (group_id,)
8787 channel = self.make_request(
8888 "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
8989 )
90 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
90 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
9191
9292 # Check other user knows they're in the group
9393 self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
102102 content={"localpart": "test"},
103103 )
104104
105 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
106
107 # Check group returns 404
108 self._check_group(group_id, expect_code=404)
105 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
106
107 # Check group returns HTTPStatus.NOT_FOUND
108 self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND)
109109
110110 # Check users don't think they're in the group
111111 self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
121121 "GET", url.encode("ascii"), access_token=self.admin_user_tok
122122 )
123123
124 self.assertEqual(
125 expect_code, int(channel.result["code"]), msg=channel.result["body"]
126 )
124 self.assertEqual(expect_code, channel.code, msg=channel.json_body)
127125
128126 def _get_groups_user_is_in(self, access_token):
129127 """Returns the list of groups the user is in (given their access token)"""
130128 channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
131129
132 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
130 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
133131
134132 return channel.json_body["groups"]
135133
209207
210208 # Should be quarantined
211209 self.assertEqual(
212 404,
213 int(channel.code),
210 HTTPStatus.NOT_FOUND,
211 channel.code,
214212 msg=(
215 "Expected to receive a 404 on accessing quarantined media: %s"
213 "Expected to receive an HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
216214 % server_and_media_id
217215 ),
218216 )
231229
232230 # Expect a forbidden error
233231 self.assertEqual(
234 403,
235 int(channel.result["code"]),
232 HTTPStatus.FORBIDDEN,
233 channel.code,
236234 msg="Expected forbidden on quarantining media as a non-admin",
237235 )
238236
246244
247245 # Expect a forbidden error
248246 self.assertEqual(
249 403,
250 int(channel.result["code"]),
247 HTTPStatus.FORBIDDEN,
248 channel.code,
251249 msg="Expected forbidden on quarantining media as a non-admin",
252250 )
253251
278276 )
279277
280278 # Should be successful
281 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
279 self.assertEqual(HTTPStatus.OK, channel.code)
282280
283281 # Quarantine the media
284282 url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
291289 access_token=admin_user_tok,
292290 )
293291 self.pump(1.0)
294 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
292 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
295293
296294 # Attempt to access the media
297295 self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
347345 access_token=admin_user_tok,
348346 )
349347 self.pump(1.0)
350 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
351 self.assertEqual(
352 json.loads(channel.result["body"].decode("utf-8")),
353 {"num_quarantined": 2},
354 "Expected 2 quarantined items",
348 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
349 self.assertEqual(
350 channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
355351 )
356352
357353 # Convert mxc URLs to server/media_id strings
395391 access_token=admin_user_tok,
396392 )
397393 self.pump(1.0)
398 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
399 self.assertEqual(
400 json.loads(channel.result["body"].decode("utf-8")),
401 {"num_quarantined": 2},
402 "Expected 2 quarantined items",
394 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
395 self.assertEqual(
396 channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
403397 )
404398
405399 # Attempt to access each piece of media
431425 url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
432426 channel = self.make_request("POST", url, access_token=admin_user_tok)
433427 self.pump(1.0)
434 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
428 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
435429
436430 # Quarantine all media by this user
437431 url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
443437 access_token=admin_user_tok,
444438 )
445439 self.pump(1.0)
446 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
447 self.assertEqual(
448 json.loads(channel.result["body"].decode("utf-8")),
449 {"num_quarantined": 1},
450 "Expected 1 quarantined item",
440 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
441 self.assertEqual(
442 channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
451443 )
452444
453445 # Attempt to access each piece of media, the first should fail, the
466458
467459 # Shouldn't be quarantined
468460 self.assertEqual(
469 200,
470 int(channel.code),
461 HTTPStatus.OK,
462 channel.code,
471463 msg=(
472 "Expected to receive a 200 on accessing not-quarantined media: %s"
464 "Expected to receive an HTTPStatus.OK on accessing not-quarantined media: %s"
473465 % server_and_media_id_2
474466 ),
475467 )
498490 def test_purge_history(self):
499491 """
500492 Simple test of purge history API.
501 Test only that is is possible to call, get status 200 and purge_id.
493 Test only that it is possible to call, get status HTTPStatus.OK and purge_id.
502494 """
503495
504496 channel = self.make_request(
508500 access_token=self.admin_user_tok,
509501 )
510502
511 self.assertEqual(200, channel.code, msg=channel.json_body)
503 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
512504 self.assertIn("purge_id", channel.json_body)
513505 purge_id = channel.json_body["purge_id"]
514506
519511 access_token=self.admin_user_tok,
520512 )
521513
522 self.assertEqual(200, channel.code, msg=channel.json_body)
514 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
523515 self.assertEqual("complete", channel.json_body["status"])
1515
1616 from parameterized import parameterized
1717
18 from twisted.test.proto_helpers import MemoryReactor
19
1820 import synapse.rest.admin
1921 from synapse.api.errors import Codes
2022 from synapse.rest.client import login
2123 from synapse.server import HomeServer
2224 from synapse.storage.background_updates import BackgroundUpdater
25 from synapse.util import Clock
2326
2427 from tests import unittest
2528
3033 login.register_servlets,
3134 ]
3235
33 def prepare(self, reactor, clock, hs: HomeServer):
36 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
3437 self.store = hs.get_datastore()
3538 self.admin_user = self.register_user("admin", "pass", admin=True)
3639 self.admin_user_tok = self.login("admin", "pass")
4346 ("POST", "/_synapse/admin/v1/background_updates/start_job"),
4447 ]
4548 )
46 def test_requester_is_no_admin(self, method: str, url: str):
47 """
48 If the user is not a server admin, an error 403 is returned.
49 def test_requester_is_no_admin(self, method: str, url: str) -> None:
50 """
51 If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
4952 """
5053
5154 self.register_user("user", "pass", admin=False)
6164 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
6265 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
6366
64 def test_invalid_parameter(self):
67 def test_invalid_parameter(self) -> None:
6568 """
6669 If parameters are invalid, an error is returned.
6770 """
8992 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
9093 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
9194
92 def _register_bg_update(self):
95 def _register_bg_update(self) -> None:
9396 "Adds a bg update but doesn't start it"
9497
9598 async def _fake_update(progress, batch_size) -> int:
111114 )
112115 )
113116
114 def test_status_empty(self):
117 def test_status_empty(self) -> None:
115118 """Test the status API works."""
116119
117120 channel = self.make_request(
126129 channel.json_body, {"current_updates": {}, "enabled": True}
127130 )
128131
129 def test_status_bg_update(self):
132 def test_status_bg_update(self) -> None:
130133 """Test the status API works with a background update."""
131134
132135 # Create a new background update
134137 self._register_bg_update()
135138
136139 self.store.db_pool.updates.start_doing_background_updates()
137 self.reactor.pump([1.0, 1.0])
140 self.reactor.pump([1.0, 1.0, 1.0])
138141
139142 channel = self.make_request(
140143 "GET",
161164 },
162165 )
163166
164 def test_enabled(self):
167 def test_enabled(self) -> None:
165168 """Test the enabled API works."""
166169
167170 # Create a new background update
298301 ),
299302 ]
300303 )
301 def test_start_backround_job(self, job_name: str, updates: Collection[str]):
304 def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None:
302305 """
302305 Test that background updates are added to the database and processed.
304307
340343 )
341344 )
342345
343 def test_start_backround_job_twice(self):
346 def test_start_backround_job_twice(self) -> None:
346 """Test that adding a background update twice returns an error."""
345348
346349 # add job to database
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13
1413 import urllib.parse
14 from http import HTTPStatus
1515
1616 from parameterized import parameterized
17
18 from twisted.test.proto_helpers import MemoryReactor
1719
1820 import synapse.rest.admin
1921 from synapse.api.errors import Codes
2022 from synapse.rest.client import login
23 from synapse.server import HomeServer
24 from synapse.util import Clock
2125
2226 from tests import unittest
2327
2933 login.register_servlets,
3034 ]
3135
32 def prepare(self, reactor, clock, hs):
36 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
3337 self.handler = hs.get_device_handler()
3438
3539 self.admin_user = self.register_user("admin", "pass", admin=True)
4650 )
4751
4852 @parameterized.expand(["GET", "PUT", "DELETE"])
49 def test_no_auth(self, method: str):
53 def test_no_auth(self, method: str) -> None:
5054 """
5155 Try to get a device of a user without authentication.
5256 """
5357 channel = self.make_request(method, self.url, b"{}")
5458
55 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
59 self.assertEqual(
60 HTTPStatus.UNAUTHORIZED,
61 channel.code,
62 msg=channel.json_body,
63 )
5664 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
5765
5866 @parameterized.expand(["GET", "PUT", "DELETE"])
59 def test_requester_is_no_admin(self, method: str):
67 def test_requester_is_no_admin(self, method: str) -> None:
6068 """
6169 If the user is not a server admin, an error is returned.
6270 """
6674 access_token=self.other_user_token,
6775 )
6876
69 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
77 self.assertEqual(
78 HTTPStatus.FORBIDDEN,
79 channel.code,
80 msg=channel.json_body,
81 )
7082 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
7183
7284 @parameterized.expand(["GET", "PUT", "DELETE"])
73 def test_user_does_not_exist(self, method: str):
74 """
75 Tests that a lookup for a user that does not exist returns a 404
85 def test_user_does_not_exist(self, method: str) -> None:
86 """
87 Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
7688 """
7789 url = (
7890 "/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
8597 access_token=self.admin_user_tok,
8698 )
8799
88 self.assertEqual(404, channel.code, msg=channel.json_body)
100 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
89101 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
90102
91103 @parameterized.expand(["GET", "PUT", "DELETE"])
92 def test_user_is_not_local(self, method: str):
93 """
94 Tests that a lookup for a user that is not a local returns a 400
104 def test_user_is_not_local(self, method: str) -> None:
105 """
106 Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
95107 """
96108 url = (
97109 "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
104116 access_token=self.admin_user_tok,
105117 )
106118
107 self.assertEqual(400, channel.code, msg=channel.json_body)
119 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
108120 self.assertEqual("Can only lookup local users", channel.json_body["error"])
109121
110 def test_unknown_device(self):
111 """
112 Tests that a lookup for a device that does not exist returns either 404 or 200.
122 def test_unknown_device(self) -> None:
123 """
124 Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
113125 """
114126 url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
115127 self.other_user
121133 access_token=self.admin_user_tok,
122134 )
123135
124 self.assertEqual(404, channel.code, msg=channel.json_body)
136 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
125137 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
126138
127139 channel = self.make_request(
130142 access_token=self.admin_user_tok,
131143 )
132144
133 self.assertEqual(200, channel.code, msg=channel.json_body)
145 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
134146
135147 channel = self.make_request(
136148 "DELETE",
138150 access_token=self.admin_user_tok,
139151 )
140152
141 # Delete unknown device returns status 200
142 self.assertEqual(200, channel.code, msg=channel.json_body)
143
144 def test_update_device_too_long_display_name(self):
153 # Deleting an unknown device returns HTTPStatus.OK
154 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
155
156 def test_update_device_too_long_display_name(self) -> None:
145157 """
146158 Update a device with a display name that is invalid (too long).
147159 """
166178 content=update,
167179 )
168180
169 self.assertEqual(400, channel.code, msg=channel.json_body)
181 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
170182 self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
171183
172184 # Ensure the display name was not updated.
176188 access_token=self.admin_user_tok,
177189 )
178190
179 self.assertEqual(200, channel.code, msg=channel.json_body)
191 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
180192 self.assertEqual("new display", channel.json_body["display_name"])
181193
182 def test_update_no_display_name(self):
183 """
184 Tests that a update for a device without JSON returns a 200
194 def test_update_no_display_name(self) -> None:
195 """
196 Tests that an update for a device without JSON returns HTTPStatus.OK
185197 """
186198 # Set initial display name.
187199 update = {"display_name": "new display"}
197209 access_token=self.admin_user_tok,
198210 )
199211
200 self.assertEqual(200, channel.code, msg=channel.json_body)
212 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
201213
202214 # Ensure the display name was not updated.
203215 channel = self.make_request(
206218 access_token=self.admin_user_tok,
207219 )
208220
209 self.assertEqual(200, channel.code, msg=channel.json_body)
221 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
210222 self.assertEqual("new display", channel.json_body["display_name"])
211223
212 def test_update_display_name(self):
224 def test_update_display_name(self) -> None:
213225 """
214226 Tests a normal successful update of display name
215227 """
221233 content={"display_name": "new displayname"},
222234 )
223235
224 self.assertEqual(200, channel.code, msg=channel.json_body)
236 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
225237
226238 # Check new display_name
227239 channel = self.make_request(
230242 access_token=self.admin_user_tok,
231243 )
232244
233 self.assertEqual(200, channel.code, msg=channel.json_body)
245 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
234246 self.assertEqual("new displayname", channel.json_body["display_name"])
235247
236 def test_get_device(self):
248 def test_get_device(self) -> None:
237249 """
238250 Tests that a normal lookup for a device is successful
239251 """
243255 access_token=self.admin_user_tok,
244256 )
245257
246 self.assertEqual(200, channel.code, msg=channel.json_body)
258 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
247259 self.assertEqual(self.other_user, channel.json_body["user_id"])
248260 # Check that all fields are available
249261 self.assertIn("user_id", channel.json_body)
252264 self.assertIn("last_seen_ip", channel.json_body)
253265 self.assertIn("last_seen_ts", channel.json_body)
254266
255 def test_delete_device(self):
267 def test_delete_device(self) -> None:
256268 """
257269 Tests that removal of a device is successful
258270 """
268280 access_token=self.admin_user_tok,
269281 )
270282
271 self.assertEqual(200, channel.code, msg=channel.json_body)
283 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
272284
273285 # Ensure that the number of devices is decreased
274286 res = self.get_success(self.handler.get_devices_by_user(self.other_user))
282294 login.register_servlets,
283295 ]
284296
285 def prepare(self, reactor, clock, hs):
297 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
286298 self.admin_user = self.register_user("admin", "pass", admin=True)
287299 self.admin_user_tok = self.login("admin", "pass")
288300
292304 self.other_user
293305 )
294306
295 def test_no_auth(self):
307 def test_no_auth(self) -> None:
296308 """
297309 Try to list devices of a user without authentication.
298310 """
299311 channel = self.make_request("GET", self.url, b"{}")
300312
301 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
313 self.assertEqual(
314 HTTPStatus.UNAUTHORIZED,
315 channel.code,
316 msg=channel.json_body,
317 )
302318 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
303319
304 def test_requester_is_no_admin(self):
320 def test_requester_is_no_admin(self) -> None:
305321 """
306322 If the user is not a server admin, an error is returned.
307323 """
313329 access_token=other_user_token,
314330 )
315331
316 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
332 self.assertEqual(
333 HTTPStatus.FORBIDDEN,
334 channel.code,
335 msg=channel.json_body,
336 )
317337 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
318338
319 def test_user_does_not_exist(self):
320 """
321 Tests that a lookup for a user that does not exist returns a 404
339 def test_user_does_not_exist(self) -> None:
340 """
341 Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
322342 """
323343 url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
324344 channel = self.make_request(
327347 access_token=self.admin_user_tok,
328348 )
329349
330 self.assertEqual(404, channel.code, msg=channel.json_body)
350 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
331351 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
332352
333 def test_user_is_not_local(self):
334 """
335 Tests that a lookup for a user that is not a local returns a 400
353 def test_user_is_not_local(self) -> None:
354 """
355 Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
336356 """
337357 url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
338358
342362 access_token=self.admin_user_tok,
343363 )
344364
345 self.assertEqual(400, channel.code, msg=channel.json_body)
365 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
346366 self.assertEqual("Can only lookup local users", channel.json_body["error"])
347367
348 def test_user_has_no_devices(self):
368 def test_user_has_no_devices(self) -> None:
349369 """
350370 Tests that a normal lookup for devices is successful
351371 if the user has no devices
358378 access_token=self.admin_user_tok,
359379 )
360380
361 self.assertEqual(200, channel.code, msg=channel.json_body)
381 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
362382 self.assertEqual(0, channel.json_body["total"])
363383 self.assertEqual(0, len(channel.json_body["devices"]))
364384
365 def test_get_devices(self):
385 def test_get_devices(self) -> None:
366386 """
367387 Tests that a normal lookup for devices is successful
368388 """
378398 access_token=self.admin_user_tok,
379399 )
380400
381 self.assertEqual(200, channel.code, msg=channel.json_body)
401 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
382402 self.assertEqual(number_devices, channel.json_body["total"])
383403 self.assertEqual(number_devices, len(channel.json_body["devices"]))
384404 self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
398418 login.register_servlets,
399419 ]
400420
401 def prepare(self, reactor, clock, hs):
421 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
402422 self.handler = hs.get_device_handler()
403423
404424 self.admin_user = self.register_user("admin", "pass", admin=True)
410430 self.other_user
411431 )
412432
413 def test_no_auth(self):
433 def test_no_auth(self) -> None:
414434 """
415435 Try to delete devices of a user without authentication.
416436 """
417437 channel = self.make_request("POST", self.url, b"{}")
418438
419 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
439 self.assertEqual(
440 HTTPStatus.UNAUTHORIZED,
441 channel.code,
442 msg=channel.json_body,
443 )
420444 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
421445
422 def test_requester_is_no_admin(self):
446 def test_requester_is_no_admin(self) -> None:
423447 """
424448 If the user is not a server admin, an error is returned.
425449 """
431455 access_token=other_user_token,
432456 )
433457
434 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
458 self.assertEqual(
459 HTTPStatus.FORBIDDEN,
460 channel.code,
461 msg=channel.json_body,
462 )
435463 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
436464
437 def test_user_does_not_exist(self):
438 """
439 Tests that a lookup for a user that does not exist returns a 404
465 def test_user_does_not_exist(self) -> None:
466 """
467 Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
440468 """
441469 url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
442470 channel = self.make_request(
445473 access_token=self.admin_user_tok,
446474 )
447475
448 self.assertEqual(404, channel.code, msg=channel.json_body)
476 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
449477 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
450478
451 def test_user_is_not_local(self):
452 """
453 Tests that a lookup for a user that is not a local returns a 400
479 def test_user_is_not_local(self) -> None:
480 """
481 Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
454482 """
455483 url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
456484
460488 access_token=self.admin_user_tok,
461489 )
462490
463 self.assertEqual(400, channel.code, msg=channel.json_body)
491 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
464492 self.assertEqual("Can only lookup local users", channel.json_body["error"])
465493
466 def test_unknown_devices(self):
467 """
468 Tests that a remove of a device that does not exist returns 200.
494 def test_unknown_devices(self) -> None:
495 """
496 Tests that removal of a device that does not exist returns HTTPStatus.OK.
469497 """
470498 channel = self.make_request(
471499 "POST",
474502 content={"devices": ["unknown_device1", "unknown_device2"]},
475503 )
476504
477 # Delete unknown devices returns status 200
478 self.assertEqual(200, channel.code, msg=channel.json_body)
479
480 def test_delete_devices(self):
505 # Deleting unknown devices returns HTTPStatus.OK
506 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
507
508 def test_delete_devices(self) -> None:
481509 """
482510 Tests that removal of devices is successful
483511 """
504532 content={"devices": device_ids},
505533 )
506534
507 self.assertEqual(200, channel.code, msg=channel.json_body)
535 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
508536
509537 res = self.get_success(self.handler.get_devices_by_user(self.other_user))
510538 self.assertEqual(0, len(res))
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13
14 import json
13 from http import HTTPStatus
14 from typing import List
15
16 from twisted.test.proto_helpers import MemoryReactor
1517
1618 import synapse.rest.admin
1719 from synapse.api.errors import Codes
1820 from synapse.rest.client import login, report_event, room
21 from synapse.server import HomeServer
22 from synapse.types import JsonDict
23 from synapse.util import Clock
1924
2025 from tests import unittest
2126
2833 report_event.register_servlets,
2934 ]
3035
31 def prepare(self, reactor, clock, hs):
36 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
3237 self.admin_user = self.register_user("admin", "pass", admin=True)
3338 self.admin_user_tok = self.login("admin", "pass")
3439
6974
7075 self.url = "/_synapse/admin/v1/event_reports"
7176
72 def test_no_auth(self):
77 def test_no_auth(self) -> None:
7378 """
7479 Try to get an event report without authentication.
7580 """
7681 channel = self.make_request("GET", self.url, b"{}")
7782
78 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
83 self.assertEqual(
84 HTTPStatus.UNAUTHORIZED,
85 channel.code,
86 msg=channel.json_body,
87 )
7988 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
8089
81 def test_requester_is_no_admin(self):
82 """
83 If the user is not a server admin, an error 403 is returned.
90 def test_requester_is_no_admin(self) -> None:
91 """
92 If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
8493 """
8594
8695 channel = self.make_request(
8998 access_token=self.other_user_tok,
9099 )
91100
92 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
101 self.assertEqual(
102 HTTPStatus.FORBIDDEN,
103 channel.code,
104 msg=channel.json_body,
105 )
93106 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
94107
95 def test_default_success(self):
108 def test_default_success(self) -> None:
96109 """
97110 Testing list of reported events
98111 """
103116 access_token=self.admin_user_tok,
104117 )
105118
106 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
119 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
107120 self.assertEqual(channel.json_body["total"], 20)
108121 self.assertEqual(len(channel.json_body["event_reports"]), 20)
109122 self.assertNotIn("next_token", channel.json_body)
110123 self._check_fields(channel.json_body["event_reports"])
111124
112 def test_limit(self):
125 def test_limit(self) -> None:
113126 """
114127 Testing list of reported events with limit
115128 """
120133 access_token=self.admin_user_tok,
121134 )
122135
123 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
136 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
124137 self.assertEqual(channel.json_body["total"], 20)
125138 self.assertEqual(len(channel.json_body["event_reports"]), 5)
126139 self.assertEqual(channel.json_body["next_token"], 5)
127140 self._check_fields(channel.json_body["event_reports"])
128141
129 def test_from(self):
142 def test_from(self) -> None:
130143 """
131144 Testing list of reported events with a defined starting point (from)
132145 """
137150 access_token=self.admin_user_tok,
138151 )
139152
140 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
153 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
141154 self.assertEqual(channel.json_body["total"], 20)
142155 self.assertEqual(len(channel.json_body["event_reports"]), 15)
143156 self.assertNotIn("next_token", channel.json_body)
144157 self._check_fields(channel.json_body["event_reports"])
145158
146 def test_limit_and_from(self):
159 def test_limit_and_from(self) -> None:
147160 """
148161 Testing list of reported events with a defined starting point and limit
149162 """
154167 access_token=self.admin_user_tok,
155168 )
156169
157 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
170 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
158171 self.assertEqual(channel.json_body["total"], 20)
159172 self.assertEqual(channel.json_body["next_token"], 15)
160173 self.assertEqual(len(channel.json_body["event_reports"]), 10)
161174 self._check_fields(channel.json_body["event_reports"])
162175
163 def test_filter_room(self):
176 def test_filter_room(self) -> None:
164177 """
165178 Testing list of reported events with a filter of room
166179 """
171184 access_token=self.admin_user_tok,
172185 )
173186
174 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
187 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
175188 self.assertEqual(channel.json_body["total"], 10)
176189 self.assertEqual(len(channel.json_body["event_reports"]), 10)
177190 self.assertNotIn("next_token", channel.json_body)
180193 for report in channel.json_body["event_reports"]:
181194 self.assertEqual(report["room_id"], self.room_id1)
182195
183 def test_filter_user(self):
196 def test_filter_user(self) -> None:
184197 """
185198 Testing list of reported events with a filter of user
186199 """
191204 access_token=self.admin_user_tok,
192205 )
193206
194 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
207 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
195208 self.assertEqual(channel.json_body["total"], 10)
196209 self.assertEqual(len(channel.json_body["event_reports"]), 10)
197210 self.assertNotIn("next_token", channel.json_body)
200213 for report in channel.json_body["event_reports"]:
201214 self.assertEqual(report["user_id"], self.other_user)
202215
203 def test_filter_user_and_room(self):
216 def test_filter_user_and_room(self) -> None:
204217 """
205218 Testing list of reported events with a filter of user and room
206219 """
211224 access_token=self.admin_user_tok,
212225 )
213226
214 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
227 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
215228 self.assertEqual(channel.json_body["total"], 5)
216229 self.assertEqual(len(channel.json_body["event_reports"]), 5)
217230 self.assertNotIn("next_token", channel.json_body)
221234 self.assertEqual(report["user_id"], self.other_user)
222235 self.assertEqual(report["room_id"], self.room_id1)
223236
224 def test_valid_search_order(self):
237 def test_valid_search_order(self) -> None:
225238 """
226239 Testing search order. Order by timestamps.
227240 """
233246 access_token=self.admin_user_tok,
234247 )
235248
236 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
249 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
237250 self.assertEqual(channel.json_body["total"], 20)
238251 self.assertEqual(len(channel.json_body["event_reports"]), 20)
239252 report = 1
251264 access_token=self.admin_user_tok,
252265 )
253266
254 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
267 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
255268 self.assertEqual(channel.json_body["total"], 20)
256269 self.assertEqual(len(channel.json_body["event_reports"]), 20)
257270 report = 1
262275 )
263276 report += 1
264277
265 def test_invalid_search_order(self):
266 """
267 Testing that a invalid search order returns a 400
278 def test_invalid_search_order(self) -> None:
279 """
280 Testing that an invalid search order returns HTTPStatus.BAD_REQUEST
268281 """
269282
270283 channel = self.make_request(
273286 access_token=self.admin_user_tok,
274287 )
275288
276 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
289 self.assertEqual(
290 HTTPStatus.BAD_REQUEST,
291 channel.code,
292 msg=channel.json_body,
293 )
277294 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
278295 self.assertEqual("Unknown direction: bar", channel.json_body["error"])
279296
280 def test_limit_is_negative(self):
281 """
282 Testing that a negative limit parameter returns a 400
297 def test_limit_is_negative(self) -> None:
298 """
299 Testing that a negative limit parameter returns HTTPStatus.BAD_REQUEST
283300 """
284301
285302 channel = self.make_request(
288305 access_token=self.admin_user_tok,
289306 )
290307
291 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
308 self.assertEqual(
309 HTTPStatus.BAD_REQUEST,
310 channel.code,
311 msg=channel.json_body,
312 )
292313 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
293314
294 def test_from_is_negative(self):
295 """
296 Testing that a negative from parameter returns a 400
315 def test_from_is_negative(self) -> None:
316 """
317 Testing that a negative from parameter returns HTTPStatus.BAD_REQUEST
297318 """
298319
299320 channel = self.make_request(
302323 access_token=self.admin_user_tok,
303324 )
304325
305 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
326 self.assertEqual(
327 HTTPStatus.BAD_REQUEST,
328 channel.code,
329 msg=channel.json_body,
330 )
306331 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
307332
308 def test_next_token(self):
333 def test_next_token(self) -> None:
309334 """
310335 Testing that `next_token` appears at the right place
311336 """
318343 access_token=self.admin_user_tok,
319344 )
320345
321 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
346 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
322347 self.assertEqual(channel.json_body["total"], 20)
323348 self.assertEqual(len(channel.json_body["event_reports"]), 20)
324349 self.assertNotIn("next_token", channel.json_body)
331356 access_token=self.admin_user_tok,
332357 )
333358
334 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
359 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
335360 self.assertEqual(channel.json_body["total"], 20)
336361 self.assertEqual(len(channel.json_body["event_reports"]), 20)
337362 self.assertNotIn("next_token", channel.json_body)
344369 access_token=self.admin_user_tok,
345370 )
346371
347 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
372 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
348373 self.assertEqual(channel.json_body["total"], 20)
349374 self.assertEqual(len(channel.json_body["event_reports"]), 19)
350375 self.assertEqual(channel.json_body["next_token"], 19)
358383 access_token=self.admin_user_tok,
359384 )
360385
361 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
386 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
362387 self.assertEqual(channel.json_body["total"], 20)
363388 self.assertEqual(len(channel.json_body["event_reports"]), 1)
364389 self.assertNotIn("next_token", channel.json_body)
365390
366 def _create_event_and_report(self, room_id, user_tok):
391 def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
367392 """Create and report events"""
368393 resp = self.helper.send(room_id, tok=user_tok)
369394 event_id = resp["event_id"]
371396 channel = self.make_request(
372397 "POST",
373398 "rooms/%s/report/%s" % (room_id, event_id),
374 json.dumps({"score": -100, "reason": "this makes me sad"}),
399 {"score": -100, "reason": "this makes me sad"},
375400 access_token=user_tok,
376401 )
377 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
378
379 def _create_event_and_report_without_parameters(self, room_id, user_tok):
402 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
403
404 def _create_event_and_report_without_parameters(
405 self, room_id: str, user_tok: str
406 ) -> None:
380407 """Create and report an event, but omit reason and score"""
381408 resp = self.helper.send(room_id, tok=user_tok)
382409 event_id = resp["event_id"]
384411 channel = self.make_request(
385412 "POST",
386413 "rooms/%s/report/%s" % (room_id, event_id),
387 json.dumps({}),
414 {},
388415 access_token=user_tok,
389416 )
390 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
391
392 def _check_fields(self, content):
417 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
418
419 def _check_fields(self, content: List[JsonDict]) -> None:
393420 """Checks that all attributes are present in an event report"""
394421 for c in content:
395422 self.assertIn("id", c)
412439 report_event.register_servlets,
413440 ]
414441
415 def prepare(self, reactor, clock, hs):
442 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
416443 self.admin_user = self.register_user("admin", "pass", admin=True)
417444 self.admin_user_tok = self.login("admin", "pass")
418445
432459 # first created event report gets `id`=2
433460 self.url = "/_synapse/admin/v1/event_reports/2"
434461
435 def test_no_auth(self):
462 def test_no_auth(self) -> None:
436463 """
437464 Try to get an event report without authentication.
438465 """
439466 channel = self.make_request("GET", self.url, b"{}")
440467
441 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
468 self.assertEqual(
469 HTTPStatus.UNAUTHORIZED,
470 channel.code,
471 msg=channel.json_body,
472 )
442473 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
443474
444 def test_requester_is_no_admin(self):
445 """
446 If the user is not a server admin, an error 403 is returned.
475 def test_requester_is_no_admin(self) -> None:
476 """
477 If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
447478 """
448479
449480 channel = self.make_request(
452483 access_token=self.other_user_tok,
453484 )
454485
455 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
486 self.assertEqual(
487 HTTPStatus.FORBIDDEN,
488 channel.code,
489 msg=channel.json_body,
490 )
456491 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
457492
458 def test_default_success(self):
493 def test_default_success(self) -> None:
459494 """
460495 Testing retrieval of a reported event
461496 """
466501 access_token=self.admin_user_tok,
467502 )
468503
469 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
504 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
470505 self._check_fields(channel.json_body)
471506
472 def test_invalid_report_id(self):
473 """
474 Testing that an invalid `report_id` returns a 400.
507 def test_invalid_report_id(self) -> None:
508 """
509 Testing that an invalid `report_id` returns HTTPStatus.BAD_REQUEST.
475510 """
476511
477512 # `report_id` is negative
481516 access_token=self.admin_user_tok,
482517 )
483518
484 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
519 self.assertEqual(
520 HTTPStatus.BAD_REQUEST,
521 channel.code,
522 msg=channel.json_body,
523 )
485524 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
486525 self.assertEqual(
487526 "The report_id parameter must be a string representing a positive integer.",
495534 access_token=self.admin_user_tok,
496535 )
497536
498 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
537 self.assertEqual(
538 HTTPStatus.BAD_REQUEST,
539 channel.code,
540 msg=channel.json_body,
541 )
499542 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
500543 self.assertEqual(
501544 "The report_id parameter must be a string representing a positive integer.",
509552 access_token=self.admin_user_tok,
510553 )
511554
512 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
555 self.assertEqual(
556 HTTPStatus.BAD_REQUEST,
557 channel.code,
558 msg=channel.json_body,
559 )
513560 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
514561 self.assertEqual(
515562 "The report_id parameter must be a string representing a positive integer.",
516563 channel.json_body["error"],
517564 )
518565
519 def test_report_id_not_found(self):
520 """
521 Testing that a not existing `report_id` returns a 404.
566 def test_report_id_not_found(self) -> None:
567 """
568 Testing that a non-existent `report_id` returns HTTPStatus.NOT_FOUND.
522569 """
523570
524571 channel = self.make_request(
527574 access_token=self.admin_user_tok,
528575 )
529576
530 self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
577 self.assertEqual(
578 HTTPStatus.NOT_FOUND,
579 channel.code,
580 msg=channel.json_body,
581 )
531582 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
532583 self.assertEqual("Event report not found", channel.json_body["error"])
533584
534 def _create_event_and_report(self, room_id, user_tok):
585 def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
535586 """Create and report events"""
536587 resp = self.helper.send(room_id, tok=user_tok)
537588 event_id = resp["event_id"]
539590 channel = self.make_request(
540591 "POST",
541592 "rooms/%s/report/%s" % (room_id, event_id),
542 json.dumps({"score": -100, "reason": "this makes me sad"}),
593 {"score": -100, "reason": "this makes me sad"},
543594 access_token=user_tok,
544595 )
545 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
546
547 def _check_fields(self, content):
596 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
597
598 def _check_fields(self, content: JsonDict) -> None:
548599 Checks that all attributes are present in an event report
549600 self.assertIn("id", content)
550601 self.assertIn("received_ts", content)
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 from http import HTTPStatus
14 from typing import List, Optional
15
16 from parameterized import parameterized
17
18 import synapse.rest.admin
19 from synapse.api.errors import Codes
20 from synapse.rest.client import login
21 from synapse.server import HomeServer
22 from synapse.types import JsonDict
23
24 from tests import unittest
25
26
27 class FederationTestCase(unittest.HomeserverTestCase):
28 servlets = [
29 synapse.rest.admin.register_servlets,
30 login.register_servlets,
31 ]
32
33 def prepare(self, reactor, clock, hs: HomeServer):
34 self.store = hs.get_datastore()
35 self.register_user("admin", "pass", admin=True)
36 self.admin_user_tok = self.login("admin", "pass")
37
38 self.url = "/_synapse/admin/v1/federation/destinations"
39
40 @parameterized.expand(
41 [
42 ("/_synapse/admin/v1/federation/destinations",),
43 ("/_synapse/admin/v1/federation/destinations/dummy",),
44 ]
45 )
46 def test_requester_is_no_admin(self, url: str):
47 """
48 If the user is not a server admin, a 403 error is returned.
49 """
50
51 self.register_user("user", "pass", admin=False)
52 other_user_tok = self.login("user", "pass")
53
54 channel = self.make_request(
55 "GET",
56 url,
57 content={},
58 access_token=other_user_tok,
59 )
60
61 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
62 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
63
64 def test_invalid_parameter(self):
65 """
66 If parameters are invalid, an error is returned.
67 """
68
69 # negative limit
70 channel = self.make_request(
71 "GET",
72 self.url + "?limit=-5",
73 access_token=self.admin_user_tok,
74 )
75
76 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
77 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
78
79 # negative from
80 channel = self.make_request(
81 "GET",
82 self.url + "?from=-5",
83 access_token=self.admin_user_tok,
84 )
85
86 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
87 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
88
89 # unknown order_by
90 channel = self.make_request(
91 "GET",
92 self.url + "?order_by=bar",
93 access_token=self.admin_user_tok,
94 )
95
96 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
97 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
98
99 # invalid search order
100 channel = self.make_request(
101 "GET",
102 self.url + "?dir=bar",
103 access_token=self.admin_user_tok,
104 )
105
106 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
107 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
108
109 # invalid destination
110 channel = self.make_request(
111 "GET",
112 self.url + "/dummy",
113 access_token=self.admin_user_tok,
114 )
115
116 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
117 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
118
119 def test_limit(self):
120 """
121 Testing list of destinations with limit
122 """
123
124 number_destinations = 20
125 self._create_destinations(number_destinations)
126
127 channel = self.make_request(
128 "GET",
129 self.url + "?limit=5",
130 access_token=self.admin_user_tok,
131 )
132
133 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
134 self.assertEqual(channel.json_body["total"], number_destinations)
135 self.assertEqual(len(channel.json_body["destinations"]), 5)
136 self.assertEqual(channel.json_body["next_token"], "5")
137 self._check_fields(channel.json_body["destinations"])
138
139 def test_from(self):
140 """
141 Testing list of destinations with a defined starting point (from)
142 """
143
144 number_destinations = 20
145 self._create_destinations(number_destinations)
146
147 channel = self.make_request(
148 "GET",
149 self.url + "?from=5",
150 access_token=self.admin_user_tok,
151 )
152
153 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
154 self.assertEqual(channel.json_body["total"], number_destinations)
155 self.assertEqual(len(channel.json_body["destinations"]), 15)
156 self.assertNotIn("next_token", channel.json_body)
157 self._check_fields(channel.json_body["destinations"])
158
159 def test_limit_and_from(self):
160 """
161 Testing list of destinations with a defined starting point and limit
162 """
163
164 number_destinations = 20
165 self._create_destinations(number_destinations)
166
167 channel = self.make_request(
168 "GET",
169 self.url + "?from=5&limit=10",
170 access_token=self.admin_user_tok,
171 )
172
173 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
174 self.assertEqual(channel.json_body["total"], number_destinations)
175 self.assertEqual(channel.json_body["next_token"], "15")
176 self.assertEqual(len(channel.json_body["destinations"]), 10)
177 self._check_fields(channel.json_body["destinations"])
178
179 def test_next_token(self):
180 """
181 Testing that `next_token` appears at the right place
182 """
183
184 number_destinations = 20
185 self._create_destinations(number_destinations)
186
187 # `next_token` does not appear
188 # Number of results is the number of entries
189 channel = self.make_request(
190 "GET",
191 self.url + "?limit=20",
192 access_token=self.admin_user_tok,
193 )
194
195 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
196 self.assertEqual(channel.json_body["total"], number_destinations)
197 self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
198 self.assertNotIn("next_token", channel.json_body)
199
200 # `next_token` does not appear
201 # Number of max results is larger than the number of entries
202 channel = self.make_request(
203 "GET",
204 self.url + "?limit=21",
205 access_token=self.admin_user_tok,
206 )
207
208 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
209 self.assertEqual(channel.json_body["total"], number_destinations)
210 self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
211 self.assertNotIn("next_token", channel.json_body)
212
213 # `next_token` does appear
214 # Number of max results is smaller than the number of entries
215 channel = self.make_request(
216 "GET",
217 self.url + "?limit=19",
218 access_token=self.admin_user_tok,
219 )
220
221 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
222 self.assertEqual(channel.json_body["total"], number_destinations)
223 self.assertEqual(len(channel.json_body["destinations"]), 19)
224 self.assertEqual(channel.json_body["next_token"], "19")
225
226 # Request the remaining entries:
227 # set `from` to the value of `next_token`
228 # `next_token` does not appear
229 channel = self.make_request(
230 "GET",
231 self.url + "?from=19",
232 access_token=self.admin_user_tok,
233 )
234
235 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
236 self.assertEqual(channel.json_body["total"], number_destinations)
237 self.assertEqual(len(channel.json_body["destinations"]), 1)
238 self.assertNotIn("next_token", channel.json_body)
239
240 def test_list_all_destinations(self):
241 """
242 List all destinations.
243 """
244 number_destinations = 5
245 self._create_destinations(number_destinations)
246
247 channel = self.make_request(
248 "GET",
249 self.url,
250 {},
251 access_token=self.admin_user_tok,
252 )
253
254 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
255 self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
256 self.assertEqual(number_destinations, channel.json_body["total"])
257
258 # Check that all fields are available
259 self._check_fields(channel.json_body["destinations"])
260
261 def test_order_by(self):
262 """
263 Testing order list with parameter `order_by`
264 """
265
266 def _order_test(
267 expected_destination_list: List[str],
268 order_by: Optional[str],
269 dir: Optional[str] = None,
270 ):
271 """Request the list of destinations in a certain order.
272 Assert that the order is what we expect.
273
274 Args:
275 expected_destination_list: The list of destinations in the order
276 we expect to get back from the server
277 order_by: The type of ordering to give the server
278 dir: The direction of ordering to give the server
279 """
280
281 url = f"{self.url}?"
282 if order_by is not None:
283 url += f"order_by={order_by}&"
284 if dir is not None and dir in ("b", "f"):
285 url += f"dir={dir}"
286 channel = self.make_request(
287 "GET",
288 url,
289 access_token=self.admin_user_tok,
290 )
291 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
292 self.assertEqual(channel.json_body["total"], len(expected_destination_list))
293
294 returned_order = [
295 row["destination"] for row in channel.json_body["destinations"]
296 ]
297 self.assertEqual(expected_destination_list, returned_order)
298 self._check_fields(channel.json_body["destinations"])
299
300 # create destinations
301 dest = [
302 ("sub-a.example.com", 100, 300, 200, 300),
303 ("sub-b.example.com", 200, 200, 100, 100),
304 ("sub-c.example.com", 300, 100, 300, 200),
305 ]
306 for (
307 destination,
308 failure_ts,
309 retry_last_ts,
310 retry_interval,
311 last_successful_stream_ordering,
312 ) in dest:
313 self.get_success(
314 self.store.set_destination_retry_timings(
315 destination, failure_ts, retry_last_ts, retry_interval
316 )
317 )
318 self.get_success(
319 self.store.set_destination_last_successful_stream_ordering(
320 destination, last_successful_stream_ordering
321 )
322 )
323
324 # order by default (destination)
325 _order_test([dest[0][0], dest[1][0], dest[2][0]], None)
326 _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f")
327 _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b")
328
329 # order by destination
330 _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination")
331 _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f")
332 _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b")
333
334 # order by failure_ts
335 _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts")
336 _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f")
337 _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b")
338
339 # order by retry_last_ts
340 _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts")
341 _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f")
342 _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b")
343
344 # order by retry_interval
345 _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval")
346 _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f")
347 _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b")
348
349 # order by last_successful_stream_ordering
350 _order_test(
351 [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering"
352 )
353 _order_test(
354 [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f"
355 )
356 _order_test(
357 [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
358 )
359
360 def test_search_term(self):
361 """Test that searching for a destination works correctly"""
362
363 def _search_test(
364 expected_destination: Optional[str],
365 search_term: str,
366 ):
367 """Search for a destination and check that the returned destinationis a match
368
369 Args:
370 expected_destination: The destination expected to be returned by the API.
371 Set to None to expect zero results for the search
372 search_term: The term to search for destination names with
373 """
374 url = f"{self.url}?destination={search_term}"
375 channel = self.make_request(
376 "GET",
377 url.encode("ascii"),
378 access_token=self.admin_user_tok,
379 )
380 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
381
382 # Check that destinations were returned
383 self.assertTrue("destinations" in channel.json_body)
384 self._check_fields(channel.json_body["destinations"])
385 destinations = channel.json_body["destinations"]
386
387 # Check that the expected number of destinations were returned
388 expected_destination_count = 1 if expected_destination else 0
389 self.assertEqual(len(destinations), expected_destination_count)
390 self.assertEqual(channel.json_body["total"], expected_destination_count)
391
392 if expected_destination:
393 # Check that the first returned destination is correct
394 self.assertEqual(expected_destination, destinations[0]["destination"])
395
396 number_destinations = 3
397 self._create_destinations(number_destinations)
398
399 # Test searching
400 _search_test("sub0.example.com", "0")
401 _search_test("sub0.example.com", "sub0")
402
403 _search_test("sub1.example.com", "1")
404 _search_test("sub1.example.com", "1.")
405
406 # Test case insensitive
407 _search_test("sub0.example.com", "SUB0")
408
409 _search_test(None, "foo")
410 _search_test(None, "bar")
411
412 def test_get_single_destination(self):
413 """
414 Get one specific destination.
415 """
416 self._create_destinations(5)
417
418 channel = self.make_request(
419 "GET",
420 self.url + "/sub0.example.com",
421 access_token=self.admin_user_tok,
422 )
423
424 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
425 self.assertEqual("sub0.example.com", channel.json_body["destination"])
426
427 # Check that all fields are available
428 # convert channel.json_body into a List
429 self._check_fields([channel.json_body])
430
431 def _create_destinations(self, number_destinations: int):
432 """Create a number of destinations
433
434 Args:
435 number_destinations: Number of destinations to be created
436 """
437 for i in range(0, number_destinations):
438 dest = f"sub{i}.example.com"
439 self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
440 self.get_success(
441 self.store.set_destination_last_successful_stream_ordering(dest, 100)
442 )
443
444 def _check_fields(self, content: List[JsonDict]):
445 """Checks that the expected destination attributes are present in content
446
447 Args:
448 content: List that is checked for content
449 """
450 for c in content:
451 self.assertIn("destination", c)
452 self.assertIn("retry_last_ts", c)
453 self.assertIn("retry_interval", c)
454 self.assertIn("failure_ts", c)
455 self.assertIn("last_successful_stream_ordering", c)
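
The pagination assertions in the tests above (`total`, `destinations`, and a `next_token` that appears only when entries remain beyond `from + limit`) all follow one rule. A minimal sketch of that contract, as a hypothetical standalone helper rather than Synapse's actual server-side code:

```python
from typing import Any, Dict, List, Optional

def paginate_destinations(
    destinations: List[Dict[str, Any]],
    start: int = 0,
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """Slice a destination list the way the admin API pages it."""
    total = len(destinations)
    if limit is None:
        limit = total
    page = destinations[start : start + limit]
    response: Dict[str, Any] = {"total": total, "destinations": page}
    next_pos = start + len(page)
    if next_pos < total:
        # Entries remain beyond this page, so tell the client where to resume.
        response["next_token"] = str(next_pos)
    return response
```

With 20 destinations, `limit=5` yields `next_token == "5"` while `from=19` returns the final entry with no `next_token`, matching the expectations in `test_limit` and `test_next_token` above.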
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 import json
1614 import os
15 from http import HTTPStatus
1716
1817 from parameterized import parameterized
18
19 from twisted.test.proto_helpers import MemoryReactor
1920
2021 import synapse.rest.admin
2122 from synapse.api.errors import Codes
2223 from synapse.rest.client import login, profile, room
2324 from synapse.rest.media.v1.filepath import MediaFilePaths
25 from synapse.server import HomeServer
26 from synapse.util import Clock
2427
2528 from tests import unittest
2629 from tests.server import FakeSite, make_request
3841 login.register_servlets,
3942 ]
4043
41 def prepare(self, reactor, clock, hs):
44 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
4245 self.media_repo = hs.get_media_repository_resource()
4346 self.server_name = hs.hostname
4447
4750
4851 self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
4952
50 def test_no_auth(self):
53 def test_no_auth(self) -> None:
5154 """
5255 Try to delete media without authentication.
5356 """
5558
5659 channel = self.make_request("DELETE", url, b"{}")
5760
58 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
61 self.assertEqual(
62 HTTPStatus.UNAUTHORIZED,
63 channel.code,
64 msg=channel.json_body,
65 )
5966 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
6067
61 def test_requester_is_no_admin(self):
68 def test_requester_is_no_admin(self) -> None:
6269 """
6370 If the user is not a server admin, an error is returned.
6471 """
7380 access_token=self.other_user_token,
7481 )
7582
76 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
83 self.assertEqual(
84 HTTPStatus.FORBIDDEN,
85 channel.code,
86 msg=channel.json_body,
87 )
7788 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
7889
79 def test_media_does_not_exist(self):
80 """
81 Tests that a lookup for a media that does not exist returns a 404
90 def test_media_does_not_exist(self) -> None:
91 """
92 Tests that a lookup for a media that does not exist returns an HTTPStatus.NOT_FOUND
8293 """
8394 url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
8495
8899 access_token=self.admin_user_tok,
89100 )
90101
91 self.assertEqual(404, channel.code, msg=channel.json_body)
102 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
92103 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
93104
94 def test_media_is_not_local(self):
95 """
96 Tests that a lookup for a media that is not a local returns a 400
105 def test_media_is_not_local(self) -> None:
106 """
107 Tests that a lookup for a media that is not local returns an HTTPStatus.BAD_REQUEST
97108 """
98109 url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
99110
103114 access_token=self.admin_user_tok,
104115 )
105116
106 self.assertEqual(400, channel.code, msg=channel.json_body)
117 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
107118 self.assertEqual("Can only delete local media", channel.json_body["error"])
108119
109 def test_delete_media(self):
120 def test_delete_media(self) -> None:
110121 """
111122 Tests that deleting a media succeeds
112123 """
116127
117128 # Upload some media into the room
118129 response = self.helper.upload_media(
119 upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
130 upload_resource,
131 SMALL_PNG,
132 tok=self.admin_user_tok,
133 expect_code=HTTPStatus.OK,
120134 )
121135 # Extract media ID from the response
122136 server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
136150
137151 # Should be successful
138152 self.assertEqual(
139 200,
153 HTTPStatus.OK,
140154 channel.code,
141155 msg=(
142 "Expected to receive a 200 on accessing media: %s" % server_and_media_id
156 "Expected to receive a HTTPStatus.OK on accessing media: %s"
157 % server_and_media_id
143158 ),
144159 )
145160
156171 access_token=self.admin_user_tok,
157172 )
158173
159 self.assertEqual(200, channel.code, msg=channel.json_body)
174 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
160175 self.assertEqual(1, channel.json_body["total"])
161176 self.assertEqual(
162177 media_id,
173188 access_token=self.admin_user_tok,
174189 )
175190 self.assertEqual(
176 404,
191 HTTPStatus.NOT_FOUND,
177192 channel.code,
178193 msg=(
179 "Expected to receive a 404 on accessing deleted media: %s"
194 "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
180195 % server_and_media_id
181196 ),
182197 )
195210 room.register_servlets,
196211 ]
197212
198 def prepare(self, reactor, clock, hs):
213 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
199214 self.media_repo = hs.get_media_repository_resource()
200215 self.server_name = hs.hostname
201216
208223 # Move clock up to somewhat realistic time
209224 self.reactor.advance(1000000000)
210225
211 def test_no_auth(self):
226 def test_no_auth(self) -> None:
212227 """
213228 Try to delete media without authentication.
214229 """
215230
216231 channel = self.make_request("POST", self.url, b"{}")
217232
218 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
233 self.assertEqual(
234 HTTPStatus.UNAUTHORIZED,
235 channel.code,
236 msg=channel.json_body,
237 )
219238 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
220239
221 def test_requester_is_no_admin(self):
240 def test_requester_is_no_admin(self) -> None:
222241 """
223242 If the user is not a server admin, an error is returned.
224243 """
231250 access_token=self.other_user_token,
232251 )
233252
234 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
253 self.assertEqual(
254 HTTPStatus.FORBIDDEN,
255 channel.code,
256 msg=channel.json_body,
257 )
235258 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
236259
237 def test_media_is_not_local(self):
238 """
239 Tests that a lookup for media that is not local returns a 400
260 def test_media_is_not_local(self) -> None:
261 """
262 Tests that a lookup for media that is not local returns an HTTPStatus.BAD_REQUEST
240263 """
241264 url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
242265
246269 access_token=self.admin_user_tok,
247270 )
248271
249 self.assertEqual(400, channel.code, msg=channel.json_body)
272 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
250273 self.assertEqual("Can only delete local media", channel.json_body["error"])
251274
252 def test_missing_parameter(self):
275 def test_missing_parameter(self) -> None:
253276 """
254277 If the parameter `before_ts` is missing, an error is returned.
255278 """
259282 access_token=self.admin_user_tok,
260283 )
261284
262 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
285 self.assertEqual(
286 HTTPStatus.BAD_REQUEST,
287 channel.code,
288 msg=channel.json_body,
289 )
263290 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
264291 self.assertEqual(
265292 "Missing integer query parameter 'before_ts'", channel.json_body["error"]
266293 )
267294
268 def test_invalid_parameter(self):
295 def test_invalid_parameter(self) -> None:
269296 """
270297 If parameters are invalid, an error is returned.
271298 """
275302 access_token=self.admin_user_tok,
276303 )
277304
278 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
305 self.assertEqual(
306 HTTPStatus.BAD_REQUEST,
307 channel.code,
308 msg=channel.json_body,
309 )
279310 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
280311 self.assertEqual(
281312 "Query parameter before_ts must be a positive integer.",
288319 access_token=self.admin_user_tok,
289320 )
290321
291 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
322 self.assertEqual(
323 HTTPStatus.BAD_REQUEST,
324 channel.code,
325 msg=channel.json_body,
326 )
292327 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
293328 self.assertEqual(
294329 "Query parameter before_ts you provided is from the year 1970. "
302337 access_token=self.admin_user_tok,
303338 )
304339
305 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
340 self.assertEqual(
341 HTTPStatus.BAD_REQUEST,
342 channel.code,
343 msg=channel.json_body,
344 )
306345 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
307346 self.assertEqual(
308347 "Query parameter size_gt must be a string representing a positive integer.",
315354 access_token=self.admin_user_tok,
316355 )
317356
318 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
357 self.assertEqual(
358 HTTPStatus.BAD_REQUEST,
359 channel.code,
360 msg=channel.json_body,
361 )
319362 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
320363 self.assertEqual(
321364 "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
322365 channel.json_body["error"],
323366 )
324367
325 def test_delete_media_never_accessed(self):
368 def test_delete_media_never_accessed(self) -> None:
326369 """
327370 Tests that media is deleted if it is older than `before_ts` and never accessed
328371 `last_access_ts` is `NULL` and `created_ts` < `before_ts`
344387 self.url + "?before_ts=" + str(now_ms),
345388 access_token=self.admin_user_tok,
346389 )
347 self.assertEqual(200, channel.code, msg=channel.json_body)
390 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
348391 self.assertEqual(1, channel.json_body["total"])
349392 self.assertEqual(
350393 media_id,
353396
354397 self._access_media(server_and_media_id, False)
355398
356 def test_keep_media_by_date(self):
399 def test_keep_media_by_date(self) -> None:
357400 """
358401 Tests that media is not deleted if it is newer than `before_ts`
359402 """
369412 self.url + "?before_ts=" + str(now_ms),
370413 access_token=self.admin_user_tok,
371414 )
372 self.assertEqual(200, channel.code, msg=channel.json_body)
415 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
373416 self.assertEqual(0, channel.json_body["total"])
374417
375418 self._access_media(server_and_media_id)
381424 self.url + "?before_ts=" + str(now_ms),
382425 access_token=self.admin_user_tok,
383426 )
384 self.assertEqual(200, channel.code, msg=channel.json_body)
427 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
385428 self.assertEqual(1, channel.json_body["total"])
386429 self.assertEqual(
387430 server_and_media_id.split("/")[1],
390433
391434 self._access_media(server_and_media_id, False)
392435
393 def test_keep_media_by_size(self):
436 def test_keep_media_by_size(self) -> None:
394437 """
395438 Tests that media is not deleted if its size is smaller than or equal
396439 to `size_gt`
405448 self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
406449 access_token=self.admin_user_tok,
407450 )
408 self.assertEqual(200, channel.code, msg=channel.json_body)
451 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
409452 self.assertEqual(0, channel.json_body["total"])
410453
411454 self._access_media(server_and_media_id)
416459 self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
417460 access_token=self.admin_user_tok,
418461 )
419 self.assertEqual(200, channel.code, msg=channel.json_body)
462 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
420463 self.assertEqual(1, channel.json_body["total"])
421464 self.assertEqual(
422465 server_and_media_id.split("/")[1],
425468
426469 self._access_media(server_and_media_id, False)
427470
428 def test_keep_media_by_user_avatar(self):
471 def test_keep_media_by_user_avatar(self) -> None:
429472 """
430473 Tests that we do not delete media if it is used as a user avatar
431474 Tests parameter `keep_profiles`
438481 channel = self.make_request(
439482 "PUT",
440483 "/profile/%s/avatar_url" % (self.admin_user,),
441 content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
442 access_token=self.admin_user_tok,
443 )
444 self.assertEqual(200, channel.code, msg=channel.json_body)
484 content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
485 access_token=self.admin_user_tok,
486 )
487 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
445488
446489 now_ms = self.clock.time_msec()
447490 channel = self.make_request(
449492 self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
450493 access_token=self.admin_user_tok,
451494 )
452 self.assertEqual(200, channel.code, msg=channel.json_body)
495 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
453496 self.assertEqual(0, channel.json_body["total"])
454497
455498 self._access_media(server_and_media_id)
460503 self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
461504 access_token=self.admin_user_tok,
462505 )
463 self.assertEqual(200, channel.code, msg=channel.json_body)
506 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
464507 self.assertEqual(1, channel.json_body["total"])
465508 self.assertEqual(
466509 server_and_media_id.split("/")[1],
469512
470513 self._access_media(server_and_media_id, False)
471514
472 def test_keep_media_by_room_avatar(self):
515 def test_keep_media_by_room_avatar(self) -> None:
473516 """
474517 Tests that we do not delete media if it is used as a room avatar
475518 Tests parameter `keep_profiles`
483526 channel = self.make_request(
484527 "PUT",
485528 "/rooms/%s/state/m.room.avatar" % (room_id,),
486 content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
487 access_token=self.admin_user_tok,
488 )
489 self.assertEqual(200, channel.code, msg=channel.json_body)
529 content={"url": "mxc://%s" % (server_and_media_id,)},
530 access_token=self.admin_user_tok,
531 )
532 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
490533
491534 now_ms = self.clock.time_msec()
492535 channel = self.make_request(
494537 self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
495538 access_token=self.admin_user_tok,
496539 )
497 self.assertEqual(200, channel.code, msg=channel.json_body)
540 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
498541 self.assertEqual(0, channel.json_body["total"])
499542
500543 self._access_media(server_and_media_id)
505548 self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
506549 access_token=self.admin_user_tok,
507550 )
508 self.assertEqual(200, channel.code, msg=channel.json_body)
551 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
509552 self.assertEqual(1, channel.json_body["total"])
510553 self.assertEqual(
511554 server_and_media_id.split("/")[1],
514557
515558 self._access_media(server_and_media_id, False)
516559
517 def _create_media(self):
560 def _create_media(self) -> str:
518561 """
519562 Create a media and return its server_and_media_id
520563 """
522565
523566 # Upload some media into the room
524567 response = self.helper.upload_media(
525 upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
568 upload_resource,
569 SMALL_PNG,
570 tok=self.admin_user_tok,
571 expect_code=HTTPStatus.OK,
526572 )
527573 # Extract media ID from the response
528574 server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
533579
534580 return server_and_media_id
535581
536 def _access_media(self, server_and_media_id, expect_success=True):
582 def _access_media(self, server_and_media_id, expect_success=True) -> None:
537583 """
538584 Try to access a media and check the result
539585 """
553599
554600 if expect_success:
555601 self.assertEqual(
556 200,
602 HTTPStatus.OK,
557603 channel.code,
558604 msg=(
559 "Expected to receive a 200 on accessing media: %s"
605 "Expected to receive a HTTPStatus.OK on accessing media: %s"
560606 % server_and_media_id
561607 ),
562608 )
564610 self.assertTrue(os.path.exists(local_path))
565611 else:
566612 self.assertEqual(
567 404,
613 HTTPStatus.NOT_FOUND,
568614 channel.code,
569615 msg=(
570 "Expected to receive a 404 on accessing deleted media: %s"
616 "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
571617 % (server_and_media_id)
572618 ),
573619 )
583629 login.register_servlets,
584630 ]
585631
586 def prepare(self, reactor, clock, hs):
632 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
587633 media_repo = hs.get_media_repository_resource()
588634 self.store = hs.get_datastore()
589635 self.server_name = hs.hostname
596642
597643 # Upload some media into the room
598644 response = self.helper.upload_media(
599 upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
645 upload_resource,
646 SMALL_PNG,
647 tok=self.admin_user_tok,
648 expect_code=HTTPStatus.OK,
600649 )
601650 # Extract media ID from the response
602651 server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
605654 self.url = "/_synapse/admin/v1/media/%s/%s/%s"
606655
607656 @parameterized.expand(["quarantine", "unquarantine"])
608 def test_no_auth(self, action: str):
657 def test_no_auth(self, action: str) -> None:
609658 """
610659 Try to protect media without authentication.
611660 """
616665 b"{}",
617666 )
618667
619 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
668 self.assertEqual(
669 HTTPStatus.UNAUTHORIZED,
670 channel.code,
671 msg=channel.json_body,
672 )
620673 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
621674
622675 @parameterized.expand(["quarantine", "unquarantine"])
623 def test_requester_is_no_admin(self, action: str):
676 def test_requester_is_no_admin(self, action: str) -> None:
624677 """
625678 If the user is not a server admin, an error is returned.
626679 """
633686 access_token=self.other_user_token,
634687 )
635688
636 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
689 self.assertEqual(
690 HTTPStatus.FORBIDDEN,
691 channel.code,
692 msg=channel.json_body,
693 )
637694 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
638695
639 def test_quarantine_media(self):
696 def test_quarantine_media(self) -> None:
640697 """
641698 Tests that quarantining a media and removing it from quarantine succeed
642699 """
651708 access_token=self.admin_user_tok,
652709 )
653710
654 self.assertEqual(200, channel.code, msg=channel.json_body)
711 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
655712 self.assertFalse(channel.json_body)
656713
657714 media_info = self.get_success(self.store.get_local_media(self.media_id))
664721 access_token=self.admin_user_tok,
665722 )
666723
667 self.assertEqual(200, channel.code, msg=channel.json_body)
724 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
668725 self.assertFalse(channel.json_body)
669726
670727 media_info = self.get_success(self.store.get_local_media(self.media_id))
671728 self.assertFalse(media_info["quarantined_by"])
672729
673 def test_quarantine_protected_media(self):
730 def test_quarantine_protected_media(self) -> None:
674731 """
675732 Tests that quarantining protected media fails
676733 """
689746 access_token=self.admin_user_tok,
690747 )
691748
692 self.assertEqual(200, channel.code, msg=channel.json_body)
749 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
693750 self.assertFalse(channel.json_body)
694751
695752 # verify that it is not in quarantine
705762 login.register_servlets,
706763 ]
707764
708 def prepare(self, reactor, clock, hs):
765 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
709766 media_repo = hs.get_media_repository_resource()
710767 self.store = hs.get_datastore()
711768
717774
718775 # Upload some media into the room
719776 response = self.helper.upload_media(
720 upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
777 upload_resource,
778 SMALL_PNG,
779 tok=self.admin_user_tok,
780 expect_code=HTTPStatus.OK,
721781 )
722782 # Extract media ID from the response
723783 server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
726786 self.url = "/_synapse/admin/v1/media/%s/%s"
727787
728788 @parameterized.expand(["protect", "unprotect"])
729 def test_no_auth(self, action: str):
789 def test_no_auth(self, action: str) -> None:
730790 """
731791 Try to protect media without authentication.
732792 """
733793
734794 channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
735795
736 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
796 self.assertEqual(
797 HTTPStatus.UNAUTHORIZED,
798 channel.code,
799 msg=channel.json_body,
800 )
737801 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
738802
739803 @parameterized.expand(["protect", "unprotect"])
740 def test_requester_is_no_admin(self, action: str):
804 def test_requester_is_no_admin(self, action: str) -> None:
741805 """
742806 If the user is not a server admin, an error is returned.
743807 """
750814 access_token=self.other_user_token,
751815 )
752816
753 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
817 self.assertEqual(
818 HTTPStatus.FORBIDDEN,
819 channel.code,
820 msg=channel.json_body,
821 )
754822 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
755823
756 def test_protect_media(self):
824 def test_protect_media(self) -> None:
757825 """
758826 Tests that protecting and unprotecting media succeed
759827 """
768836 access_token=self.admin_user_tok,
769837 )
770838
771 self.assertEqual(200, channel.code, msg=channel.json_body)
839 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
772840 self.assertFalse(channel.json_body)
773841
774842 media_info = self.get_success(self.store.get_local_media(self.media_id))
781849 access_token=self.admin_user_tok,
782850 )
783851
784 self.assertEqual(200, channel.code, msg=channel.json_body)
852 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
785853 self.assertFalse(channel.json_body)
786854
787855 media_info = self.get_success(self.store.get_local_media(self.media_id))
798866 room.register_servlets,
799867 ]
800868
801 def prepare(self, reactor, clock, hs):
869 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
802870 self.media_repo = hs.get_media_repository_resource()
803871 self.server_name = hs.hostname
804872
808876 self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
809877 self.url = "/_synapse/admin/v1/purge_media_cache"
810878
811 def test_no_auth(self):
879 def test_no_auth(self) -> None:
812880 """
813881 Try to delete media without authentication.
814882 """
815883
816884 channel = self.make_request("POST", self.url, b"{}")
817885
818 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
886 self.assertEqual(
887 HTTPStatus.UNAUTHORIZED,
888 channel.code,
889 msg=channel.json_body,
890 )
819891 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
820892
821 def test_requester_is_not_admin(self):
893 def test_requester_is_not_admin(self) -> None:
822894 """
823895 If the user is not a server admin, an error is returned.
824896 """
831903 access_token=self.other_user_token,
832904 )
833905
834 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
906 self.assertEqual(
907 HTTPStatus.FORBIDDEN,
908 channel.code,
909 msg=channel.json_body,
910 )
835911 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
836912
837 def test_invalid_parameter(self):
913 def test_invalid_parameter(self) -> None:
838914 """
839915 If parameters are invalid, an error is returned.
840916 """
844920 access_token=self.admin_user_tok,
845921 )
846922
847 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
923 self.assertEqual(
924 HTTPStatus.BAD_REQUEST,
925 channel.code,
926 msg=channel.json_body,
927 )
848928 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
849929 self.assertEqual(
850930 "Query parameter before_ts must be a positive integer.",
857937 access_token=self.admin_user_tok,
858938 )
859939
860 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
940 self.assertEqual(
941 HTTPStatus.BAD_REQUEST,
942 channel.code,
943 msg=channel.json_body,
944 )
861945 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
862946 self.assertEqual(
863947 "Query parameter before_ts you provided is from the year 1970. "
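The dominant change across the file above is replacing bare integers (`200`, `400`, `401`, ...) with constants from `http.HTTPStatus`. Because `HTTPStatus` is an `IntEnum`, the rewritten assertions compare equal exactly as before; a minimal standalone sketch of the equivalence (not using Synapse's test helpers):

```python
from http import HTTPStatus

# HTTPStatus members are IntEnum values, so each constant compares
# equal to the bare integer it replaces -- the assertions in the
# diff change in readability only, not in behaviour.
assert HTTPStatus.OK == 200
assert HTTPStatus.BAD_REQUEST == 400
assert HTTPStatus.UNAUTHORIZED == 401
assert HTTPStatus.FORBIDDEN == 403
assert HTTPStatus.NOT_FOUND == 404

# And a constant remains usable anywhere an int is expected.
assert HTTPStatus.OK + 0 == 200
```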
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13
1413 import random
1514 import string
15 from http import HTTPStatus
16
17 from twisted.test.proto_helpers import MemoryReactor
1618
1719 import synapse.rest.admin
1820 from synapse.api.errors import Codes
1921 from synapse.rest.client import login
22 from synapse.server import HomeServer
23 from synapse.util import Clock
2024
2125 from tests import unittest
2226
2731 login.register_servlets,
2832 ]
2933
30 def prepare(self, reactor, clock, hs):
34 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
3135 self.store = hs.get_datastore()
3236 self.admin_user = self.register_user("admin", "pass", admin=True)
3337 self.admin_user_tok = self.login("admin", "pass")
3741
3842 self.url = "/_synapse/admin/v1/registration_tokens"
3943
40 def _new_token(self, **kwargs):
44 def _new_token(self, **kwargs) -> str:
4145 """Helper function to create a token."""
4246 token = kwargs.get(
4347 "token",
5963
6064 # CREATION
6165
62 def test_create_no_auth(self):
66 def test_create_no_auth(self) -> None:
6367 """Try to create a token without authentication."""
6468 channel = self.make_request("POST", self.url + "/new", {})
65 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
69 self.assertEqual(
70 HTTPStatus.UNAUTHORIZED,
71 channel.code,
72 msg=channel.json_body,
73 )
6674 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
6775
68 def test_create_requester_not_admin(self):
76 def test_create_requester_not_admin(self) -> None:
6977 """Try to create a token while not an admin."""
7078 channel = self.make_request(
7179 "POST",
7381 {},
7482 access_token=self.other_user_tok,
7583 )
76 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
84 self.assertEqual(
85 HTTPStatus.FORBIDDEN,
86 channel.code,
87 msg=channel.json_body,
88 )
7789 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
7890
79 def test_create_using_defaults(self):
91 def test_create_using_defaults(self) -> None:
8092 """Create a token using all the defaults."""
8193 channel = self.make_request(
8294 "POST",
8597 access_token=self.admin_user_tok,
8698 )
8799
88 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
100 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
89101 self.assertEqual(len(channel.json_body["token"]), 16)
90102 self.assertIsNone(channel.json_body["uses_allowed"])
91103 self.assertIsNone(channel.json_body["expiry_time"])
92104 self.assertEqual(channel.json_body["pending"], 0)
93105 self.assertEqual(channel.json_body["completed"], 0)
94106
95 def test_create_specifying_fields(self):
107 def test_create_specifying_fields(self) -> None:
96108 """Create a token specifying the value of all fields."""
97109 # As many of the allowed characters as possible with length <= 64
98110 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-"
109121 access_token=self.admin_user_tok,
110122 )
111123
112 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
124 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
113125 self.assertEqual(channel.json_body["token"], token)
114126 self.assertEqual(channel.json_body["uses_allowed"], 1)
115127 self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
116128 self.assertEqual(channel.json_body["pending"], 0)
117129 self.assertEqual(channel.json_body["completed"], 0)
118130
119 def test_create_with_null_value(self):
131 def test_create_with_null_value(self) -> None:
120132 """Create a token specifying unlimited uses and no expiry."""
121133 data = {
122134 "uses_allowed": None,
130142 access_token=self.admin_user_tok,
131143 )
132144
133 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
145 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
134146 self.assertEqual(len(channel.json_body["token"]), 16)
135147 self.assertIsNone(channel.json_body["uses_allowed"])
136148 self.assertIsNone(channel.json_body["expiry_time"])
137149 self.assertEqual(channel.json_body["pending"], 0)
138150 self.assertEqual(channel.json_body["completed"], 0)
139151
140 def test_create_token_too_long(self):
152 def test_create_token_too_long(self) -> None:
141153 """Check token longer than 64 chars is invalid."""
142154 data = {"token": "a" * 65}
143155
148160 access_token=self.admin_user_tok,
149161 )
150162
151 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
152 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
153
154 def test_create_token_invalid_chars(self):
163 self.assertEqual(
164 HTTPStatus.BAD_REQUEST,
165 channel.code,
166 msg=channel.json_body,
167 )
168 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
169
170 def test_create_token_invalid_chars(self) -> None:
155171 """Check you can't create token with invalid characters."""
156172 data = {
157173 "token": "abc/def",
164180 access_token=self.admin_user_tok,
165181 )
166182
167 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
168 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
169
170 def test_create_token_already_exists(self):
183 self.assertEqual(
184 HTTPStatus.BAD_REQUEST,
185 channel.code,
186 msg=channel.json_body,
187 )
188 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
189
190 def test_create_token_already_exists(self) -> None:
171191 """Check you can't create token that already exists."""
172192 data = {
173193 "token": "abcd",
179199 data,
180200 access_token=self.admin_user_tok,
181201 )
182 self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
202 self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body)
183203
184204 channel2 = self.make_request(
185205 "POST",
187207 data,
188208 access_token=self.admin_user_tok,
189209 )
190 self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
210 self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
191211 self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
192212
193 def test_create_unable_to_generate_token(self):
213 def test_create_unable_to_generate_token(self) -> None:
194214 """Check right error is raised when server can't generate unique token."""
195215 # Create all possible single character tokens
196216 tokens = []
219239 {"length": 1},
220240 access_token=self.admin_user_tok,
221241 )
222 self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
223
224 def test_create_uses_allowed(self):
242 self.assertEqual(500, channel.code, msg=channel.json_body)
243
244 def test_create_uses_allowed(self) -> None:
225245 """Check you can only create a token with good values for uses_allowed."""
226246 # Should work with 0 (token is invalid from the start)
227247 channel = self.make_request(
230250 {"uses_allowed": 0},
231251 access_token=self.admin_user_tok,
232252 )
233 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
253 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
234254 self.assertEqual(channel.json_body["uses_allowed"], 0)
235255
236256 # Should fail with negative integer
240260 {"uses_allowed": -5},
241261 access_token=self.admin_user_tok,
242262 )
243 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
263 self.assertEqual(
264 HTTPStatus.BAD_REQUEST,
265 channel.code,
266 msg=channel.json_body,
267 )
244268 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
245269
246270 # Should fail with float
250274 {"uses_allowed": 1.5},
251275 access_token=self.admin_user_tok,
252276 )
253 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
254 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
255
256 def test_create_expiry_time(self):
277 self.assertEqual(
278 HTTPStatus.BAD_REQUEST,
279 channel.code,
280 msg=channel.json_body,
281 )
282 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
283
284 def test_create_expiry_time(self) -> None:
257285 """Check you can't create a token with an invalid expiry_time."""
258286 # Should fail with a time in the past
259287 channel = self.make_request(
262290 {"expiry_time": self.clock.time_msec() - 10000},
263291 access_token=self.admin_user_tok,
264292 )
265 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
293 self.assertEqual(
294 HTTPStatus.BAD_REQUEST,
295 channel.code,
296 msg=channel.json_body,
297 )
266298 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
267299
268300 # Should fail with float
272304 {"expiry_time": self.clock.time_msec() + 1000000.5},
273305 access_token=self.admin_user_tok,
274306 )
275 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
276 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
277
278 def test_create_length(self):
307 self.assertEqual(
308 HTTPStatus.BAD_REQUEST,
309 channel.code,
310 msg=channel.json_body,
311 )
312 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
313
314 def test_create_length(self) -> None:
279315 """Check you can only generate a token with a valid length."""
280316 # Should work with 64
281317 channel = self.make_request(
284320 {"length": 64},
285321 access_token=self.admin_user_tok,
286322 )
287 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
323 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
288324 self.assertEqual(len(channel.json_body["token"]), 64)
289325
290326 # Should fail with 0
294330 {"length": 0},
295331 access_token=self.admin_user_tok,
296332 )
297 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
333 self.assertEqual(
334 HTTPStatus.BAD_REQUEST,
335 channel.code,
336 msg=channel.json_body,
337 )
298338 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
299339
300340 # Should fail with a negative integer
304344 {"length": -5},
305345 access_token=self.admin_user_tok,
306346 )
307 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
347 self.assertEqual(
348 HTTPStatus.BAD_REQUEST,
349 channel.code,
350 msg=channel.json_body,
351 )
308352 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
309353
310354 # Should fail with a float
314358 {"length": 8.5},
315359 access_token=self.admin_user_tok,
316360 )
317 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
361 self.assertEqual(
362 HTTPStatus.BAD_REQUEST,
363 channel.code,
364 msg=channel.json_body,
365 )
318366 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
319367
320368 # Should fail with 65
324372 {"length": 65},
325373 access_token=self.admin_user_tok,
326374 )
327 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
375 self.assertEqual(
376 HTTPStatus.BAD_REQUEST,
377 channel.code,
378 msg=channel.json_body,
379 )
328380 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
329381
330382 # UPDATING
331383
332 def test_update_no_auth(self):
384 def test_update_no_auth(self) -> None:
333385 """Try to update a token without authentication."""
334386 channel = self.make_request(
335387 "PUT",
336388 self.url + "/1234", # Token doesn't exist but that doesn't matter
337389 {},
338390 )
339 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
391 self.assertEqual(
392 HTTPStatus.UNAUTHORIZED,
393 channel.code,
394 msg=channel.json_body,
395 )
340396 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
341397
342 def test_update_requester_not_admin(self):
398 def test_update_requester_not_admin(self) -> None:
343399 """Try to update a token while not an admin."""
344400 channel = self.make_request(
345401 "PUT",
347403 {},
348404 access_token=self.other_user_tok,
349405 )
350 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
406 self.assertEqual(
407 HTTPStatus.FORBIDDEN,
408 channel.code,
409 msg=channel.json_body,
410 )
351411 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
352412
353 def test_update_non_existent(self):
413 def test_update_non_existent(self) -> None:
354414 """Try to update a token that doesn't exist."""
355415 channel = self.make_request(
356416 "PUT",
359419 access_token=self.admin_user_tok,
360420 )
361421
362 self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
422 self.assertEqual(
423 HTTPStatus.NOT_FOUND,
424 channel.code,
425 msg=channel.json_body,
426 )
363427 self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
364428
365 def test_update_uses_allowed(self):
429 def test_update_uses_allowed(self) -> None:
366430 """Test updating just uses_allowed."""
367431 # Create new token using default values
368432 token = self._new_token()
374438 {"uses_allowed": 1},
375439 access_token=self.admin_user_tok,
376440 )
377 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
441 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
378442 self.assertEqual(channel.json_body["uses_allowed"], 1)
379443 self.assertIsNone(channel.json_body["expiry_time"])
380444
385449 {"uses_allowed": 0},
386450 access_token=self.admin_user_tok,
387451 )
388 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
452 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
389453 self.assertEqual(channel.json_body["uses_allowed"], 0)
390454 self.assertIsNone(channel.json_body["expiry_time"])
391455
396460 {"uses_allowed": None},
397461 access_token=self.admin_user_tok,
398462 )
399 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
463 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
400464 self.assertIsNone(channel.json_body["uses_allowed"])
401465 self.assertIsNone(channel.json_body["expiry_time"])
402466
407471 {"uses_allowed": 1.5},
408472 access_token=self.admin_user_tok,
409473 )
410 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
474 self.assertEqual(
475 HTTPStatus.BAD_REQUEST,
476 channel.code,
477 msg=channel.json_body,
478 )
411479 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
412480
413481 # Should fail with a negative integer
417485 {"uses_allowed": -5},
418486 access_token=self.admin_user_tok,
419487 )
420 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
421 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
422
423 def test_update_expiry_time(self):
488 self.assertEqual(
489 HTTPStatus.BAD_REQUEST,
490 channel.code,
491 msg=channel.json_body,
492 )
493 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
494
495 def test_update_expiry_time(self) -> None:
424496 """Test updating just expiry_time."""
425497 # Create new token using default values
426498 token = self._new_token()
433505 {"expiry_time": new_expiry_time},
434506 access_token=self.admin_user_tok,
435507 )
436 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
508 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
437509 self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
438510 self.assertIsNone(channel.json_body["uses_allowed"])
439511
444516 {"expiry_time": None},
445517 access_token=self.admin_user_tok,
446518 )
447 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
519 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
448520 self.assertIsNone(channel.json_body["expiry_time"])
449521 self.assertIsNone(channel.json_body["uses_allowed"])
450522
456528 {"expiry_time": past_time},
457529 access_token=self.admin_user_tok,
458530 )
459 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
531 self.assertEqual(
532 HTTPStatus.BAD_REQUEST,
533 channel.code,
534 msg=channel.json_body,
535 )
460536 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
461537
462538 # Should fail a float
466542 {"expiry_time": new_expiry_time + 0.5},
467543 access_token=self.admin_user_tok,
468544 )
469 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
470 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
471
472 def test_update_both(self):
545 self.assertEqual(
546 HTTPStatus.BAD_REQUEST,
547 channel.code,
548 msg=channel.json_body,
549 )
550 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
551
552 def test_update_both(self) -> None:
473553 """Test updating both uses_allowed and expiry_time."""
474554 # Create new token using default values
475555 token = self._new_token()
487567 access_token=self.admin_user_tok,
488568 )
489569
490 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
570 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
491571 self.assertEqual(channel.json_body["uses_allowed"], 1)
492572 self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
493573
494 def test_update_invalid_type(self):
574 def test_update_invalid_type(self) -> None:
495575 """Test using invalid types doesn't work."""
496576 # Create new token using default values
497577 token = self._new_token()
508588 access_token=self.admin_user_tok,
509589 )
510590
511 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
591 self.assertEqual(
592 HTTPStatus.BAD_REQUEST,
593 channel.code,
594 msg=channel.json_body,
595 )
512596 self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
513597
514598 # DELETING
515599
516 def test_delete_no_auth(self):
600 def test_delete_no_auth(self) -> None:
517601 """Try to delete a token without authentication."""
518602 channel = self.make_request(
519603 "DELETE",
520604 self.url + "/1234", # Token doesn't exist but that doesn't matter
521605 {},
522606 )
523 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
607 self.assertEqual(
608 HTTPStatus.UNAUTHORIZED,
609 channel.code,
610 msg=channel.json_body,
611 )
524612 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
525613
526 def test_delete_requester_not_admin(self):
614 def test_delete_requester_not_admin(self) -> None:
527615 """Try to delete a token while not an admin."""
528616 channel = self.make_request(
529617 "DELETE",
531619 {},
532620 access_token=self.other_user_tok,
533621 )
534 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
622 self.assertEqual(
623 HTTPStatus.FORBIDDEN,
624 channel.code,
625 msg=channel.json_body,
626 )
535627 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
536628
537 def test_delete_non_existent(self):
629 def test_delete_non_existent(self) -> None:
538630 """Try to delete a token that doesn't exist."""
539631 channel = self.make_request(
540632 "DELETE",
543635 access_token=self.admin_user_tok,
544636 )
545637
546 self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
638 self.assertEqual(
639 HTTPStatus.NOT_FOUND,
640 channel.code,
641 msg=channel.json_body,
642 )
547643 self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
548644
549 def test_delete(self):
645 def test_delete(self) -> None:
550646 """Test deleting a token."""
551647 # Create new token using default values
552648 token = self._new_token()
558654 access_token=self.admin_user_tok,
559655 )
560656
561 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
657 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
562658
563659 # GETTING ONE
564660
565 def test_get_no_auth(self):
661 def test_get_no_auth(self) -> None:
566662 """Try to get a token without authentication."""
567663 channel = self.make_request(
568664 "GET",
569665 self.url + "/1234", # Token doesn't exist but that doesn't matter
570666 {},
571667 )
572 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
668 self.assertEqual(
669 HTTPStatus.UNAUTHORIZED,
670 channel.code,
671 msg=channel.json_body,
672 )
573673 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
574674
575 def test_get_requester_not_admin(self):
675 def test_get_requester_not_admin(self) -> None:
576676 """Try to get a token while not an admin."""
577677 channel = self.make_request(
578678 "GET",
580680 {},
581681 access_token=self.other_user_tok,
582682 )
583 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
683 self.assertEqual(
684 HTTPStatus.FORBIDDEN,
685 channel.code,
686 msg=channel.json_body,
687 )
584688 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
585689
586 def test_get_non_existent(self):
690 def test_get_non_existent(self) -> None:
587691 """Try to get a token that doesn't exist."""
588692 channel = self.make_request(
589693 "GET",
592696 access_token=self.admin_user_tok,
593697 )
594698
595 self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
699 self.assertEqual(
700 HTTPStatus.NOT_FOUND,
701 channel.code,
702 msg=channel.json_body,
703 )
596704 self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
597705
598 def test_get(self):
706 def test_get(self) -> None:
599707 """Test getting a token."""
600708 # Create new token using default values
601709 token = self._new_token()
607715 access_token=self.admin_user_tok,
608716 )
609717
610 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
718 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
611719 self.assertEqual(channel.json_body["token"], token)
612720 self.assertIsNone(channel.json_body["uses_allowed"])
613721 self.assertIsNone(channel.json_body["expiry_time"])
616724
617725 # LISTING
618726
619 def test_list_no_auth(self):
727 def test_list_no_auth(self) -> None:
620728 """Try to list tokens without authentication."""
621729 channel = self.make_request("GET", self.url, {})
622 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
730 self.assertEqual(
731 HTTPStatus.UNAUTHORIZED,
732 channel.code,
733 msg=channel.json_body,
734 )
623735 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
624736
625 def test_list_requester_not_admin(self):
737 def test_list_requester_not_admin(self) -> None:
626738 """Try to list tokens while not an admin."""
627739 channel = self.make_request(
628740 "GET",
630742 {},
631743 access_token=self.other_user_tok,
632744 )
633 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
745 self.assertEqual(
746 HTTPStatus.FORBIDDEN,
747 channel.code,
748 msg=channel.json_body,
749 )
634750 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
635751
636 def test_list_all(self):
752 def test_list_all(self) -> None:
637753 """Test listing all tokens."""
638754 # Create new token using default values
639755 token = self._new_token()
645761 access_token=self.admin_user_tok,
646762 )
647763
648 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
764 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
649765 self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
650766 token_info = channel.json_body["registration_tokens"][0]
651767 self.assertEqual(token_info["token"], token)
654770 self.assertEqual(token_info["pending"], 0)
655771 self.assertEqual(token_info["completed"], 0)
656772
657 def test_list_invalid_query_parameter(self):
773 def test_list_invalid_query_parameter(self) -> None:
658774 """Test with `valid` query parameter not `true` or `false`."""
659775 channel = self.make_request(
660776 "GET",
663779 access_token=self.admin_user_tok,
664780 )
665781
666 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
667
668 def _test_list_query_parameter(self, valid: str):
782 self.assertEqual(
783 HTTPStatus.BAD_REQUEST,
784 channel.code,
785 msg=channel.json_body,
786 )
787
788 def _test_list_query_parameter(self, valid: str) -> None:
669789 """Helper used to test both valid=true and valid=false."""
670790 # Create 2 valid and 2 invalid tokens.
671791 now = self.hs.get_clock().time_msec()
695815 access_token=self.admin_user_tok,
696816 )
697817
698 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
818 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
699819 self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
700820 token_info_1 = channel.json_body["registration_tokens"][0]
701821 token_info_2 = channel.json_body["registration_tokens"][1]
702822 self.assertIn(token_info_1["token"], tokens)
703823 self.assertIn(token_info_2["token"], tokens)
704824
705 def test_list_valid(self):
825 def test_list_valid(self) -> None:
706826 """Test listing just valid tokens."""
707827 self._test_list_query_parameter(valid="true")
708828
709 def test_list_invalid(self):
829 def test_list_invalid(self) -> None:
710830 """Test listing just invalid tokens."""
711831 self._test_list_query_parameter(valid="false")
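A second recurring cleanup in the file above is rewriting `int(channel.result["code"])` as `channel.code`. The `FakeChannel` below is a hypothetical stand-in for Synapse's test channel, sketched only to show that the two spellings read the same value:

```python
from http import HTTPStatus


class FakeChannel:
    """Hypothetical stand-in for Synapse's test channel: it keeps the
    raw response in `result` and exposes the parsed status as `code`."""

    def __init__(self, status: bytes, body: bytes) -> None:
        self.result = {"code": status, "body": body}

    @property
    def code(self) -> int:
        # Exactly what the old call sites spelled out by hand:
        return int(self.result["code"])


channel = FakeChannel(b"401", b'{"errcode": "M_MISSING_TOKEN"}')

# Old spelling and new spelling are equivalent; the new one also
# compares cleanly against HTTPStatus constants.
assert int(channel.result["code"]) == 401
assert channel.code == HTTPStatus.UNAUTHORIZED
```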
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13
14 import json
1513 import urllib.parse
1614 from http import HTTPStatus
1715 from typing import List, Optional
1816 from unittest.mock import Mock
1917
2018 from parameterized import parameterized
19
20 from twisted.test.proto_helpers import MemoryReactor
2121
2222 import synapse.rest.admin
2323 from synapse.api.constants import EventTypes, Membership
2424 from synapse.api.errors import Codes
2525 from synapse.handlers.pagination import PaginationHandler
2626 from synapse.rest.client import directory, events, login, room
27 from synapse.server import HomeServer
28 from synapse.util import Clock
2729
2830 from tests import unittest
2931
3941 room.register_deprecated_servlets,
4042 ]
4143
42 def prepare(self, reactor, clock, hs):
44 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
4345 self.event_creation_handler = hs.get_event_creation_handler()
4446 hs.config.consent.user_consent_version = "1"
4547
6567
6668 def test_requester_is_no_admin(self):
6769 """
68 If the user is not a server admin, an error 403 is returned.
70 If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
6971 """
7072
7173 channel = self.make_request(
7577 access_token=self.other_user_tok,
7678 )
7779
78 self.assertEqual(403, channel.code, msg=channel.json_body)
80 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
7981 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
8082
8183 def test_room_does_not_exist(self):
8284 """
83 Check that unknown rooms/server return error 404.
85 Check that unknown rooms/server return 200
8486 """
8587 url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test"
8688
9193 access_token=self.admin_user_tok,
9294 )
9395
94 self.assertEqual(404, channel.code, msg=channel.json_body)
95 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
96 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
9697
9798 def test_room_is_not_valid(self):
9899 """
99 Check that invalid room names, return an error 400.
100 Check that invalid room names return an error HTTPStatus.BAD_REQUEST.
100101 """
101102 url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
102103
107108 access_token=self.admin_user_tok,
108109 )
109110
110 self.assertEqual(400, channel.code, msg=channel.json_body)
111 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
111112 self.assertEqual(
112113 "invalidroom is not a legal room ID",
113114 channel.json_body["error"],
117118 """
118119 Tests that the user ID must be from local server but it does not have to exist.
119120 """
120 body = json.dumps({"new_room_user_id": "@unknown:test"})
121121
122122 channel = self.make_request(
123123 "DELETE",
124124 self.url,
125 content=body,
126 access_token=self.admin_user_tok,
127 )
128
129 self.assertEqual(200, channel.code, msg=channel.json_body)
125 content={"new_room_user_id": "@unknown:test"},
126 access_token=self.admin_user_tok,
127 )
128
129 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
130130 self.assertIn("new_room_id", channel.json_body)
131131 self.assertIn("kicked_users", channel.json_body)
132132 self.assertIn("failed_to_kick_users", channel.json_body)
136136 """
137137 Check that only local users can create new room to move members.
138138 """
139 body = json.dumps({"new_room_user_id": "@not:exist.bla"})
140139
141140 channel = self.make_request(
142141 "DELETE",
143142 self.url,
144 content=body,
145 access_token=self.admin_user_tok,
146 )
147
148 self.assertEqual(400, channel.code, msg=channel.json_body)
143 content={"new_room_user_id": "@not:exist.bla"},
144 access_token=self.admin_user_tok,
145 )
146
147 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
149148 self.assertEqual(
150149 "User must be our own: @not:exist.bla",
151150 channel.json_body["error"],
155154 """
156155 If parameter `block` is not boolean, return an error
157156 """
158 body = json.dumps({"block": "NotBool"})
159157
160158 channel = self.make_request(
161159 "DELETE",
162160 self.url,
163 content=body,
164 access_token=self.admin_user_tok,
165 )
166
167 self.assertEqual(400, channel.code, msg=channel.json_body)
161 content={"block": "NotBool"},
162 access_token=self.admin_user_tok,
163 )
164
165 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
168166 self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
169167
170168 def test_purge_is_not_bool(self):
171169 """
172170 If parameter `purge` is not boolean, return an error
173171 """
174 body = json.dumps({"purge": "NotBool"})
175172
176173 channel = self.make_request(
177174 "DELETE",
178175 self.url,
179 content=body,
180 access_token=self.admin_user_tok,
181 )
182
183 self.assertEqual(400, channel.code, msg=channel.json_body)
176 content={"purge": "NotBool"},
177 access_token=self.admin_user_tok,
178 )
179
180 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
184181 self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
185182
186183 def test_purge_room_and_block(self):
197194 # Assert one user in room
198195 self._is_member(room_id=self.room_id, user_id=self.other_user)
199196
200 body = json.dumps({"block": True, "purge": True})
201
202197 channel = self.make_request(
203198 "DELETE",
204199 self.url.encode("ascii"),
205 content=body,
206 access_token=self.admin_user_tok,
207 )
208
209 self.assertEqual(200, channel.code, msg=channel.json_body)
200 content={"block": True, "purge": True},
201 access_token=self.admin_user_tok,
202 )
203
204 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
210205 self.assertEqual(None, channel.json_body["new_room_id"])
211206 self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
212207 self.assertIn("failed_to_kick_users", channel.json_body)
230225 # Assert one user in room
231226 self._is_member(room_id=self.room_id, user_id=self.other_user)
232227
233 body = json.dumps({"block": False, "purge": True})
234
235228 channel = self.make_request(
236229 "DELETE",
237230 self.url.encode("ascii"),
238 content=body,
239 access_token=self.admin_user_tok,
240 )
241
242 self.assertEqual(200, channel.code, msg=channel.json_body)
231 content={"block": False, "purge": True},
232 access_token=self.admin_user_tok,
233 )
234
235 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
243236 self.assertEqual(None, channel.json_body["new_room_id"])
244237 self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
245238 self.assertIn("failed_to_kick_users", channel.json_body)
264257 # Assert one user in room
265258 self._is_member(room_id=self.room_id, user_id=self.other_user)
266259
267 body = json.dumps({"block": True, "purge": False})
268
269260 channel = self.make_request(
270261 "DELETE",
271262 self.url.encode("ascii"),
272 content=body,
273 access_token=self.admin_user_tok,
274 )
275
276 self.assertEqual(200, channel.code, msg=channel.json_body)
263 content={"block": True, "purge": False},
264 access_token=self.admin_user_tok,
265 )
266
267 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
277268 self.assertEqual(None, channel.json_body["new_room_id"])
278269 self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
279270 self.assertIn("failed_to_kick_users", channel.json_body)
304295 )
305296
306297 # The room is now blocked.
307 self.assertEqual(
308 HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
309 )
298 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
310299 self._is_blocked(room_id)
311300
312301 def test_shutdown_room_consent(self):
326315
327316 # Assert that the user is getting consent error
328317 self.helper.send(
329 self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
318 self.room_id,
319 body="foo",
320 tok=self.other_user_tok,
321 expect_code=HTTPStatus.FORBIDDEN,
330322 )
331323
332324 # Test that room is not purged
340332 channel = self.make_request(
341333 "DELETE",
342334 self.url,
343 json.dumps({"new_room_user_id": self.admin_user}),
344 access_token=self.admin_user_tok,
345 )
346
347 self.assertEqual(200, channel.code, msg=channel.json_body)
335 {"new_room_user_id": self.admin_user},
336 access_token=self.admin_user_tok,
337 )
338
339 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
348340 self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
349341 self.assertIn("new_room_id", channel.json_body)
350342 self.assertIn("failed_to_kick_users", channel.json_body)
370362 channel = self.make_request(
371363 "PUT",
372364 url.encode("ascii"),
373 json.dumps({"history_visibility": "world_readable"}),
365 {"history_visibility": "world_readable"},
374366 access_token=self.other_user_tok,
375367 )
376 self.assertEqual(200, channel.code, msg=channel.json_body)
368 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
377369
378370 # Test that room is not purged
379371 with self.assertRaises(AssertionError):
386378 channel = self.make_request(
387379 "DELETE",
388380 self.url,
389 json.dumps({"new_room_user_id": self.admin_user}),
390 access_token=self.admin_user_tok,
391 )
392
393 self.assertEqual(200, channel.code, msg=channel.json_body)
381 {"new_room_user_id": self.admin_user},
382 access_token=self.admin_user_tok,
383 )
384
385 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
394386 self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
395387 self.assertIn("new_room_id", channel.json_body)
396388 self.assertIn("failed_to_kick_users", channel.json_body)
405397 self._has_no_members(self.room_id)
406398
407399 # Assert we can no longer peek into the room
408 self._assert_peek(self.room_id, expect_code=403)
400 self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
409401
410402 def _is_blocked(self, room_id, expect=True):
411403 """Assert that the room is blocked or not"""
464456 room.register_deprecated_servlets,
465457 ]
466458
467 def prepare(self, reactor, clock, hs):
459 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
468460 self.event_creation_handler = hs.get_event_creation_handler()
469461 hs.config.consent.user_consent_version = "1"
470462
501493 )
502494 def test_requester_is_no_admin(self, method: str, url: str):
503495 """
504 If the user is not a server admin, an error 403 is returned.
496 If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
505497 """
506498
507499 channel = self.make_request(
514506 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
515507 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
516508
517 @parameterized.expand(
518 [
519 ("DELETE", "/_synapse/admin/v2/rooms/%s"),
520 ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
521 ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
522 ]
523 )
524 def test_room_does_not_exist(self, method: str, url: str):
525 """
526 Check that unknown rooms/server return error 404.
527 """
528
529 channel = self.make_request(
530 method,
531 url % "!unknown:test",
509 def test_room_does_not_exist(self):
510 """
511 Check that unknown rooms/servers return 200
512
513 This is important, as it allows incomplete vestiges of rooms to be cleared up
514 even if the create event/etc is missing.
515 """
516 room_id = "!unknown:test"
517 channel = self.make_request(
518 "DELETE",
519 f"/_synapse/admin/v2/rooms/{room_id}",
532520 content={},
533521 access_token=self.admin_user_tok,
534522 )
535523
536 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
537 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
524 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
525 self.assertIn("delete_id", channel.json_body)
526 delete_id = channel.json_body["delete_id"]
527
528 # get status
529 channel = self.make_request(
530 "GET",
531 f"/_synapse/admin/v2/rooms/{room_id}/delete_status",
532 access_token=self.admin_user_tok,
533 )
534
535 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
536 self.assertEqual(1, len(channel.json_body["results"]))
537 self.assertEqual("complete", channel.json_body["results"][0]["status"])
538 self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
538539
539540 @parameterized.expand(
540541 [
544545 )
545546 def test_room_is_not_valid(self, method: str, url: str):
546547 """
547 Check that invalid room names, return an error 400.
548 Check that invalid room names return an error HTTPStatus.BAD_REQUEST.
548549 """
549550
550551 channel = self.make_request(
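The rewritten `test_room_does_not_exist` above exercises the new asynchronous v2 deletion flow: `DELETE /_synapse/admin/v2/rooms/<room_id>` now returns a `delete_id` even for unknown rooms (so vestiges can be cleared up), and `GET .../delete_status` reports per-task progress. A hedged client-side sketch of that flow; `request` here is a hypothetical HTTP helper, not part of Synapse — only the endpoint paths and response keys come from the tests above:

```python
from typing import Callable, Dict, Tuple

def delete_room_and_wait(
    request: Callable[[str, str], Tuple[int, Dict]],
    room_id: str,
    max_polls: int = 100,
) -> str:
    """Schedule a v2 room deletion and poll until every task completes.

    `request(method, path)` is a stand-in returning (status_code, json_body).
    """
    # Kick off the deletion; the response carries an opaque delete_id.
    status, body = request("DELETE", f"/_synapse/admin/v2/rooms/{room_id}")
    assert status == 200
    delete_id = body["delete_id"]

    # Poll the status endpoint until all reported tasks are "complete".
    for _ in range(max_polls):
        status, body = request(
            "GET", f"/_synapse/admin/v2/rooms/{room_id}/delete_status"
        )
        assert status == 200
        if all(r["status"] == "complete" for r in body["results"]):
            return delete_id
    raise TimeoutError(f"deletion {delete_id} did not complete")
```

In the test above only one poll is needed, since deleting an unknown room has no members to kick and no events to purge.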
853854
854855 # Assert that the user is getting consent error
855856 self.helper.send(
856 self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
857 self.room_id,
858 body="foo",
859 tok=self.other_user_tok,
860 expect_code=HTTPStatus.FORBIDDEN,
857861 )
858862
859863 # Test that room is not purged
950954 self._has_no_members(self.room_id)
951955
952956 # Assert we can no longer peek into the room
953 self._assert_peek(self.room_id, expect_code=403)
957 self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
954958
955959 def _is_blocked(self, room_id: str, expect: bool = True) -> None:
956960 """Assert that the room is blocked or not"""
10681072 directory.register_servlets,
10691073 ]
10701074
1071 def prepare(self, reactor, clock, hs):
1075 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
10721076 # Create user
10731077 self.admin_user = self.register_user("admin", "pass", admin=True)
10741078 self.admin_user_tok = self.login("admin", "pass")
10751079
1076 def test_list_rooms(self):
1080 def test_list_rooms(self) -> None:
10771081 """Test that we can list rooms"""
10781082 # Create 3 test rooms
10791083 total_rooms = 3
10931097 )
10941098
10951099 # Check request completed successfully
1096 self.assertEqual(200, channel.code, msg=channel.json_body)
1100 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10971101
10981102 # Check that response json body contains a "rooms" key
10991103 self.assertTrue(
11371141 # We shouldn't receive a next token here as there's no further rooms to show
11381142 self.assertNotIn("next_batch", channel.json_body)
11391143
1140 def test_list_rooms_pagination(self):
1144 def test_list_rooms_pagination(self) -> None:
11411145 """Test that we can get a full list of rooms through pagination"""
11421146 # Create 5 test rooms
11431147 total_rooms = 5
11771181 url.encode("ascii"),
11781182 access_token=self.admin_user_tok,
11791183 )
1180 self.assertEqual(200, channel.code, msg=channel.json_body)
1184 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
11811185
11821186 self.assertTrue("rooms" in channel.json_body)
11831187 for r in channel.json_body["rooms"]:
12171221 url.encode("ascii"),
12181222 access_token=self.admin_user_tok,
12191223 )
1220 self.assertEqual(200, channel.code, msg=channel.json_body)
1221
1222 def test_correct_room_attributes(self):
1224 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
1225
1226 def test_correct_room_attributes(self) -> None:
12231227 """Test the correct attributes for a room are returned"""
12241228 # Create a test room
12251229 room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
12401244 {"room_id": room_id},
12411245 access_token=self.admin_user_tok,
12421246 )
1243 self.assertEqual(200, channel.code, msg=channel.json_body)
1247 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
12441248
12451249 # Set this new alias as the canonical alias for this room
12461250 self.helper.send_state(
12721276 url.encode("ascii"),
12731277 access_token=self.admin_user_tok,
12741278 )
1275 self.assertEqual(200, channel.code, msg=channel.json_body)
1279 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
12761280
12771281 # Check that rooms were returned
12781282 self.assertTrue("rooms" in channel.json_body)
13001304 self.assertEqual(test_room_name, r["name"])
13011305 self.assertEqual(test_alias, r["canonical_alias"])
13021306
1303 def test_room_list_sort_order(self):
1307 def test_room_list_sort_order(self) -> None:
13041308 """Test room list sort ordering. alphabetical name versus number of members,
13051309 reversing the order, etc.
13061310 """
13091313 order_type: str,
13101314 expected_room_list: List[str],
13111315 reverse: bool = False,
1312 ):
1316 ) -> None:
13131317 """Request the list of rooms in a certain order. Assert that order is what
13141318 we expect
13151319
13271331 url.encode("ascii"),
13281332 access_token=self.admin_user_tok,
13291333 )
1330 self.assertEqual(200, channel.code, msg=channel.json_body)
1334 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
13311335
13321336 # Check that rooms were returned
13331337 self.assertTrue("rooms" in channel.json_body)
14381442 _order_test("state_events", [room_id_3, room_id_2, room_id_1])
14391443 _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
14401444
1441 def test_search_term(self):
1445 def test_search_term(self) -> None:
14421446 """Test that searching for a room works correctly"""
14431447 # Create two test rooms
14441448 room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
14661470 def _search_test(
14671471 expected_room_id: Optional[str],
14681472 search_term: str,
1469 expected_http_code: int = 200,
1470 ):
1473 expected_http_code: int = HTTPStatus.OK,
1474 ) -> None:
14711475 """Search for a room and check that the returned room's id is a match
14721476
14731477 Args:
14841488 )
14851489 self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
14861490
1487 if expected_http_code != 200:
1491 if expected_http_code != HTTPStatus.OK:
14881492 return
14891493
14901494 # Check that rooms were returned
15271531
15281532 _search_test(None, "foo")
15291533 _search_test(None, "bar")
1530 _search_test(None, "", expected_http_code=400)
1534 _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
15311535
15321536 # Test that the whole room id returns the room
15331537 _search_test(room_id_1, room_id_1)
15411545 # Test search local part of alias
15421546 _search_test(room_id_1, "alias1")
15431547
1544 def test_search_term_non_ascii(self):
1548 def test_search_term_non_ascii(self) -> None:
15451549 """Test that searching for a room with non-ASCII characters works correctly"""
15461550
15471551 # Create test room
15641568 url.encode("ascii"),
15651569 access_token=self.admin_user_tok,
15661570 )
1567 self.assertEqual(200, channel.code, msg=channel.json_body)
1571 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
15681572 self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
15691573 self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
15701574
1571 def test_single_room(self):
1575 def test_single_room(self) -> None:
15721576 """Test that a single room can be requested correctly"""
15731577 # Create two test rooms
15741578 room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
15971601 url.encode("ascii"),
15981602 access_token=self.admin_user_tok,
15991603 )
1600 self.assertEqual(200, channel.code, msg=channel.json_body)
1604 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16011605
16021606 self.assertIn("room_id", channel.json_body)
16031607 self.assertIn("name", channel.json_body)
16191623
16201624 self.assertEqual(room_id_1, channel.json_body["room_id"])
16211625
1622 def test_single_room_devices(self):
1626 def test_single_room_devices(self) -> None:
16231627 """Test that `joined_local_devices` can be requested correctly"""
16241628 room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
16251629
16291633 url.encode("ascii"),
16301634 access_token=self.admin_user_tok,
16311635 )
1632 self.assertEqual(200, channel.code, msg=channel.json_body)
1636 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16331637 self.assertEqual(1, channel.json_body["joined_local_devices"])
16341638
16351639 # Have another user join the room
16431647 url.encode("ascii"),
16441648 access_token=self.admin_user_tok,
16451649 )
1646 self.assertEqual(200, channel.code, msg=channel.json_body)
1650 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16471651 self.assertEqual(2, channel.json_body["joined_local_devices"])
16481652
16491653 # leave room
16551659 url.encode("ascii"),
16561660 access_token=self.admin_user_tok,
16571661 )
1658 self.assertEqual(200, channel.code, msg=channel.json_body)
1662 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16591663 self.assertEqual(0, channel.json_body["joined_local_devices"])
16601664
1661 def test_room_members(self):
1665 def test_room_members(self) -> None:
16621666 """Test that room members can be requested correctly"""
16631667 # Create two test rooms
16641668 room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
16861690 url.encode("ascii"),
16871691 access_token=self.admin_user_tok,
16881692 )
1689 self.assertEqual(200, channel.code, msg=channel.json_body)
1693 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16901694
16911695 self.assertCountEqual(
16921696 ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
16991703 url.encode("ascii"),
17001704 access_token=self.admin_user_tok,
17011705 )
1702 self.assertEqual(200, channel.code, msg=channel.json_body)
1706 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17031707
17041708 self.assertCountEqual(
17051709 ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
17061710 )
17071711 self.assertEqual(channel.json_body["total"], 3)
17081712
1709 def test_room_state(self):
1713 def test_room_state(self) -> None:
17101714 """Test that room state can be requested correctly"""
17111715 # Create two test rooms
17121716 room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
17171721 url.encode("ascii"),
17181722 access_token=self.admin_user_tok,
17191723 )
1720 self.assertEqual(200, channel.code, msg=channel.json_body)
1724 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17211725 self.assertIn("state", channel.json_body)
17221726 # testing that the state events match is painful and not done here. We assume that
17231727 # the create_room already does the right thing, so no need to verify that we got
17241728 # the state events it created.
17251729
1726 def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str):
1730 def _set_canonical_alias(
1731 self, room_id: str, test_alias: str, admin_user_tok: str
1732 ) -> None:
17271733 # Create a new alias to this room
17281734 url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
17291735 channel = self.make_request(
17321738 {"room_id": room_id},
17331739 access_token=admin_user_tok,
17341740 )
1735 self.assertEqual(200, channel.code, msg=channel.json_body)
1741 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17361742
17371743 # Set this new alias as the canonical alias for this room
17381744 self.helper.send_state(
17581764 login.register_servlets,
17591765 ]
17601766
1761 def prepare(self, reactor, clock, homeserver):
1767 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
17621768 self.admin_user = self.register_user("admin", "pass", admin=True)
17631769 self.admin_user_tok = self.login("admin", "pass")
17641770
17731779 )
17741780 self.url = f"/_synapse/admin/v1/join/{self.public_room_id}"
17751781
1776 def test_requester_is_no_admin(self):
1777 """
1778 If the user is not a server admin, an error 403 is returned.
1779 """
1780 body = json.dumps({"user_id": self.second_user_id})
1782 def test_requester_is_no_admin(self) -> None:
1783 """
1784 If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
1785 """
17811786
17821787 channel = self.make_request(
17831788 "POST",
17841789 self.url,
1785 content=body,
1790 content={"user_id": self.second_user_id},
17861791 access_token=self.second_tok,
17871792 )
17881793
1789 self.assertEqual(403, channel.code, msg=channel.json_body)
1794 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
17901795 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
17911796
1792 def test_invalid_parameter(self):
1797 def test_invalid_parameter(self) -> None:
17931798 """
17941799 If a parameter is missing, return an error
17951800 """
1796 body = json.dumps({"unknown_parameter": "@unknown:test"})
17971801
17981802 channel = self.make_request(
17991803 "POST",
18001804 self.url,
1801 content=body,
1802 access_token=self.admin_user_tok,
1803 )
1804
1805 self.assertEqual(400, channel.code, msg=channel.json_body)
1805 content={"unknown_parameter": "@unknown:test"},
1806 access_token=self.admin_user_tok,
1807 )
1808
1809 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
18061810 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
18071811
1808 def test_local_user_does_not_exist(self):
1809 """
1810 Tests that a lookup for a user that does not exist returns a 404
1811 """
1812 body = json.dumps({"user_id": "@unknown:test"})
1812 def test_local_user_does_not_exist(self) -> None:
1813 """
1814 Tests that a lookup for a user that does not exist returns an HTTPStatus.NOT_FOUND
1815 """
18131816
18141817 channel = self.make_request(
18151818 "POST",
18161819 self.url,
1817 content=body,
1818 access_token=self.admin_user_tok,
1819 )
1820
1821 self.assertEqual(404, channel.code, msg=channel.json_body)
1820 content={"user_id": "@unknown:test"},
1821 access_token=self.admin_user_tok,
1822 )
1823
1824 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
18221825 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
18231826
1824 def test_remote_user(self):
1827 def test_remote_user(self) -> None:
18251828 """
18261829 Check that only local user can join rooms.
18271830 """
1828 body = json.dumps({"user_id": "@not:exist.bla"})
18291831
18301832 channel = self.make_request(
18311833 "POST",
18321834 self.url,
1833 content=body,
1834 access_token=self.admin_user_tok,
1835 )
1836
1837 self.assertEqual(400, channel.code, msg=channel.json_body)
1835 content={"user_id": "@not:exist.bla"},
1836 access_token=self.admin_user_tok,
1837 )
1838
1839 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
18381840 self.assertEqual(
18391841 "This endpoint can only be used with local users",
18401842 channel.json_body["error"],
18411843 )
18421844
1843 def test_room_does_not_exist(self):
1844 """
1845 Check that unknown rooms/server return error 404.
1846 """
1847 body = json.dumps({"user_id": self.second_user_id})
1845 def test_room_does_not_exist(self) -> None:
1846 """
1847 Check that unknown rooms/servers return an error HTTPStatus.NOT_FOUND.
1848 """
18481849 url = "/_synapse/admin/v1/join/!unknown:test"
18491850
18501851 channel = self.make_request(
18511852 "POST",
18521853 url,
1853 content=body,
1854 access_token=self.admin_user_tok,
1855 )
1856
1857 self.assertEqual(404, channel.code, msg=channel.json_body)
1854 content={"user_id": self.second_user_id},
1855 access_token=self.admin_user_tok,
1856 )
1857
1858 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
18581859 self.assertEqual("No known servers", channel.json_body["error"])
18591860
1860 def test_room_is_not_valid(self):
1861 """
1862 Check that invalid room names, return an error 400.
1863 """
1864 body = json.dumps({"user_id": self.second_user_id})
1861 def test_room_is_not_valid(self) -> None:
1862 """
1863 Check that invalid room names return an error HTTPStatus.BAD_REQUEST.
1864 """
18651865 url = "/_synapse/admin/v1/join/invalidroom"
18661866
18671867 channel = self.make_request(
18681868 "POST",
18691869 url,
1870 content=body,
1871 access_token=self.admin_user_tok,
1872 )
1873
1874 self.assertEqual(400, channel.code, msg=channel.json_body)
1870 content={"user_id": self.second_user_id},
1871 access_token=self.admin_user_tok,
1872 )
1873
1874 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
18751875 self.assertEqual(
18761876 "invalidroom was not legal room ID or room alias",
18771877 channel.json_body["error"],
18781878 )
18791879
1880 def test_join_public_room(self):
1880 def test_join_public_room(self) -> None:
18811881 """
18821882 Test joining a local user to a public room with "JoinRules.PUBLIC"
18831883 """
1884 body = json.dumps({"user_id": self.second_user_id})
18851884
18861885 channel = self.make_request(
18871886 "POST",
18881887 self.url,
1889 content=body,
1890 access_token=self.admin_user_tok,
1891 )
1892
1893 self.assertEqual(200, channel.code, msg=channel.json_body)
1888 content={"user_id": self.second_user_id},
1889 access_token=self.admin_user_tok,
1890 )
1891
1892 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
18941893 self.assertEqual(self.public_room_id, channel.json_body["room_id"])
18951894
18961895 # Validate if user is a member of the room
19001899 "/_matrix/client/r0/joined_rooms",
19011900 access_token=self.second_tok,
19021901 )
1903 self.assertEquals(200, channel.code, msg=channel.json_body)
1902 self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
19041903 self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
19051904
1906 def test_join_private_room_if_not_member(self):
1905 def test_join_private_room_if_not_member(self) -> None:
19071906 """
19081907 Test joining a local user to a private room with "JoinRules.INVITE"
19091908 when server admin is not member of this room.
19121911 self.creator, tok=self.creator_tok, is_public=False
19131912 )
19141913 url = f"/_synapse/admin/v1/join/{private_room_id}"
1915 body = json.dumps({"user_id": self.second_user_id})
19161914
19171915 channel = self.make_request(
19181916 "POST",
19191917 url,
1920 content=body,
1921 access_token=self.admin_user_tok,
1922 )
1923
1924 self.assertEqual(403, channel.code, msg=channel.json_body)
1918 content={"user_id": self.second_user_id},
1919 access_token=self.admin_user_tok,
1920 )
1921
1922 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
19251923 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
19261924
1927 def test_join_private_room_if_member(self):
1925 def test_join_private_room_if_member(self) -> None:
19281926 """
19291927 Test joining a local user to a private room with "JoinRules.INVITE",
19301928 when server admin is member of this room.
19491947 "/_matrix/client/r0/joined_rooms",
19501948 access_token=self.admin_user_tok,
19511949 )
1952 self.assertEquals(200, channel.code, msg=channel.json_body)
1950 self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
19531951 self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
19541952
19551953 # Join user to room.
19561954
19571955 url = f"/_synapse/admin/v1/join/{private_room_id}"
1958 body = json.dumps({"user_id": self.second_user_id})
19591956
19601957 channel = self.make_request(
19611958 "POST",
19621959 url,
1963 content=body,
1964 access_token=self.admin_user_tok,
1965 )
1966 self.assertEqual(200, channel.code, msg=channel.json_body)
1960 content={"user_id": self.second_user_id},
1961 access_token=self.admin_user_tok,
1962 )
1963 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
19671964 self.assertEqual(private_room_id, channel.json_body["room_id"])
19681965
19691966 # Validate if user is a member of the room
19731970 "/_matrix/client/r0/joined_rooms",
19741971 access_token=self.second_tok,
19751972 )
1976 self.assertEquals(200, channel.code, msg=channel.json_body)
1973 self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
19771974 self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
19781975
1979 def test_join_private_room_if_owner(self):
1976 def test_join_private_room_if_owner(self) -> None:
19801977 """
19811978 Test joining a local user to a private room with "JoinRules.INVITE",
19821979 when server admin is owner of this room.
19851982 self.admin_user, tok=self.admin_user_tok, is_public=False
19861983 )
19871984 url = f"/_synapse/admin/v1/join/{private_room_id}"
1988 body = json.dumps({"user_id": self.second_user_id})
19891985
19901986 channel = self.make_request(
19911987 "POST",
19921988 url,
1993 content=body,
1994 access_token=self.admin_user_tok,
1995 )
1996
1997 self.assertEqual(200, channel.code, msg=channel.json_body)
1989 content={"user_id": self.second_user_id},
1990 access_token=self.admin_user_tok,
1991 )
1992
1993 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
19981994 self.assertEqual(private_room_id, channel.json_body["room_id"])
19991995
20001996 # Validate if user is a member of the room
20042000 "/_matrix/client/r0/joined_rooms",
20052001 access_token=self.second_tok,
20062002 )
2007 self.assertEquals(200, channel.code, msg=channel.json_body)
2003 self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
20082004 self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
20092005
2010 def test_context_as_non_admin(self):
2006 def test_context_as_non_admin(self) -> None:
20112007 """
20122008 Test that, without being admin, one cannot use the context admin API
20132009 """
20382034 % (room_id, events[midway]["event_id"]),
20392035 access_token=tok,
20402036 )
2041 self.assertEquals(403, channel.code, msg=channel.json_body)
2037 self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
20422038 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
20432039
2044 def test_context_as_admin(self):
2040 def test_context_as_admin(self) -> None:
20452041 """
20462042 Test that, as admin, we can find the context of an event without having joined the room.
20472043 """
20682064 % (room_id, events[midway]["event_id"]),
20692065 access_token=self.admin_user_tok,
20702066 )
2071 self.assertEquals(200, channel.code, msg=channel.json_body)
2067 self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
20722068 self.assertEquals(
20732069 channel.json_body["event"]["event_id"], events[midway]["event_id"]
20742070 )
20972093 login.register_servlets,
20982094 ]
20992095
2100 def prepare(self, reactor, clock, homeserver):
2096 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
21012097 self.admin_user = self.register_user("admin", "pass", admin=True)
21022098 self.admin_user_tok = self.login("admin", "pass")
21032099
21142110 self.public_room_id
21152111 )
21162112
2117 def test_public_room(self):
2113 def test_public_room(self) -> None:
21182114 """Test that getting admin in a public room works."""
21192115 room_id = self.helper.create_room_as(
21202116 self.creator, tok=self.creator_tok, is_public=True
21272123 access_token=self.admin_user_tok,
21282124 )
21292125
2130 self.assertEqual(200, channel.code, msg=channel.json_body)
2126 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
21312127
21322128 # Now we test that we can join the room and ban a user.
21332129 self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
21392135 tok=self.admin_user_tok,
21402136 )
21412137
2142 def test_private_room(self):
2138 def test_private_room(self) -> None:
21432139 """Test that getting admin in a private room works and we get invited."""
21442140 room_id = self.helper.create_room_as(
21452141 self.creator,
21542150 access_token=self.admin_user_tok,
21552151 )
21562152
2157 self.assertEqual(200, channel.code, msg=channel.json_body)
2153 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
21582154
21592155 # Now we test that we can join the room (we should have received an
21602156 # invite) and can ban a user.
21672163 tok=self.admin_user_tok,
21682164 )
21692165
2170 def test_other_user(self):
2166 def test_other_user(self) -> None:
21712167 """Test that granting admin in a public room to a non-admin user works."""
21722168 room_id = self.helper.create_room_as(
21732169 self.creator, tok=self.creator_tok, is_public=True
21802176 access_token=self.admin_user_tok,
21812177 )
21822178
2183 self.assertEqual(200, channel.code, msg=channel.json_body)
2179 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
21842180
21852181 # Now we test that we can join the room and ban a user.
21862182 self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
21922188 tok=self.second_tok,
21932189 )
21942190
2195 def test_not_enough_power(self):
2191 def test_not_enough_power(self) -> None:
21962192 """Test that we get a sensible error if there are no local room admins."""
21972193 room_id = self.helper.create_room_as(
21982194 self.creator, tok=self.creator_tok, is_public=True
22142210 access_token=self.admin_user_tok,
22152211 )
22162212
2217 # We expect this to fail with a 400 as there are no room admins.
2213 # We expect this to fail with HTTPStatus.BAD_REQUEST as there are no room admins.
22182214 #
22192215 # (Note we assert the error message to ensure that it's not denied for
22202216 # some other reason)
2221 self.assertEqual(400, channel.code, msg=channel.json_body)
2217 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
22222218 self.assertEqual(
22232219 channel.json_body["error"],
22242220 "No local admin user in room with power to update power levels.",
22322228 login.register_servlets,
22332229 ]
22342230
2235 def prepare(self, reactor, clock, hs):
2231 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
22362232 self._store = hs.get_datastore()
22372233
22382234 self.admin_user = self.register_user("admin", "pass", admin=True)
22472243 self.url = "/_synapse/admin/v1/rooms/%s/block"
22482244
22492245 @parameterized.expand([("PUT",), ("GET",)])
2250 def test_requester_is_no_admin(self, method: str):
2251 """If the user is not a server admin, an error 403 is returned."""
2246 def test_requester_is_no_admin(self, method: str) -> None:
2247 """If the user is not a server admin, an error with status HTTPStatus.FORBIDDEN is returned."""
22522248
22532249 channel = self.make_request(
22542250 method,
22612257 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
22622258
22632259 @parameterized.expand([("PUT",), ("GET",)])
2264 def test_room_is_not_valid(self, method: str):
2265 """Check that invalid room names, return an error 400."""
2260 def test_room_is_not_valid(self, method: str) -> None:
2261 """Check that invalid room names return an error with status HTTPStatus.BAD_REQUEST."""
22662262
22672263 channel = self.make_request(
22682264 method,
22772273 channel.json_body["error"],
22782274 )
22792275
2280 def test_block_is_not_valid(self):
2276 def test_block_is_not_valid(self) -> None:
22812277 """If parameter `block` is not valid, return an error."""
22822278
22832279 # `block` is not valid
23122308 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
23132309 self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
23142310
2315 def test_block_room(self):
2311 def test_block_room(self) -> None:
23162312 """Test that blocking a room succeeds."""
23172313
23182314 def _request_and_test_block_room(room_id: str) -> None:
23362332 # unknown remote room
23372333 _request_and_test_block_room("!unknown:remote")
23382334
2339 def test_block_room_twice(self):
2335 def test_block_room_twice(self) -> None:
23402336 """Test that blocking a room that is already blocked succeeds."""
23412337
23422338 self._is_blocked(self.room_id, expect=False)
23512347 self.assertTrue(channel.json_body["block"])
23522348 self._is_blocked(self.room_id, expect=True)
23532349
2354 def test_unblock_room(self):
2350 def test_unblock_room(self) -> None:
23552351 """Test that unblocking a room succeeds."""
23562352
23572353 def _request_and_test_unblock_room(room_id: str) -> None:
23762372 # unknown remote room
23772373 _request_and_test_unblock_room("!unknown:remote")
23782374
2379 def test_unblock_room_twice(self):
2375 def test_unblock_room_twice(self) -> None:
23802376 """Test that unblocking a room that is not blocked succeeds."""
23812377
23822378 self._block_room(self.room_id)
23912387 self.assertFalse(channel.json_body["block"])
23922388 self._is_blocked(self.room_id, expect=False)
23932389
2394 def test_get_blocked_room(self):
2390 def test_get_blocked_room(self) -> None:
23952391 """Test getting the status of a blocked room"""
23962392
23972393 def _request_blocked_room(room_id: str) -> None:
24152411 # unknown remote room
24162412 _request_blocked_room("!unknown:remote")
24172413
2418 def test_get_unblocked_room(self):
2414 def test_get_unblocked_room(self) -> None:
24192415 """Test getting the status of an unblocked room"""
24202416
24212417 def _request_unblocked_room(room_id: str) -> None:
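Throughout these hunks, bare status literals (200, 400, 403) are replaced with `http.HTTPStatus` constants. The swap is safe because `HTTPStatus` members are `IntEnum` values that compare equal to the raw integers, while carrying a readable name and reason phrase; a quick sketch:

```python
from http import HTTPStatus


def describe(code: int) -> str:
    """Turn a raw status code into its standard '<code> <phrase>' form."""
    status = HTTPStatus(code)  # e.g. HTTPStatus(200) is HTTPStatus.OK
    return f"{status.value} {status.phrase}"
```

Because the constants compare equal to plain ints, assertions like `assertEqual(HTTPStatus.OK, channel.code)` behave identically to the old `assertEqual(200, ...)` while documenting intent.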
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13
13 from http import HTTPStatus
1414 from typing import List
15
16 from twisted.test.proto_helpers import MemoryReactor
1517
1618 import synapse.rest.admin
1719 from synapse.api.errors import Codes
1820 from synapse.rest.client import login, room, sync
21 from synapse.server import HomeServer
1922 from synapse.storage.roommember import RoomsForUser
2023 from synapse.types import JsonDict
24 from synapse.util import Clock
2125
2226 from tests import unittest
2327 from tests.unittest import override_config
3236 sync.register_servlets,
3337 ]
3438
35 def prepare(self, reactor, clock, hs):
39 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
3640 self.store = hs.get_datastore()
3741 self.room_shutdown_handler = hs.get_room_shutdown_handler()
3842 self.pagination_handler = hs.get_pagination_handler()
4751
4852 self.url = "/_synapse/admin/v1/send_server_notice"
4953
50 def test_no_auth(self):
54 def test_no_auth(self) -> None:
5155 """Try to send a server notice without authentication."""
5256 channel = self.make_request("POST", self.url)
5357
54 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
58 self.assertEqual(
59 HTTPStatus.UNAUTHORIZED,
60 channel.code,
61 msg=channel.json_body,
62 )
5563 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
5664
57 def test_requester_is_no_admin(self):
65 def test_requester_is_no_admin(self) -> None:
5866 """If the user is not a server admin, an error is returned."""
5967 channel = self.make_request(
6068 "POST",
6270 access_token=self.other_user_token,
6371 )
6472
65 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
73 self.assertEqual(
74 HTTPStatus.FORBIDDEN,
75 channel.code,
76 msg=channel.json_body,
77 )
6678 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
6779
6880 @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
69 def test_user_does_not_exist(self):
70 """Tests that a lookup for a user that does not exist returns a 404"""
81 def test_user_does_not_exist(self) -> None:
82 """Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND"""
7183 channel = self.make_request(
7284 "POST",
7385 self.url,
7587 content={"user_id": "@unknown_person:test", "content": ""},
7688 )
7789
78 self.assertEqual(404, channel.code, msg=channel.json_body)
90 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
7991 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
8092
8193 @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
82 def test_user_is_not_local(self):
83 """
84 Tests that a lookup for a user that is not a local returns a 400
94 def test_user_is_not_local(self) -> None:
95 """
96 Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
8597 """
8698 channel = self.make_request(
8799 "POST",
93105 },
94106 )
95107
96 self.assertEqual(400, channel.code, msg=channel.json_body)
108 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
97109 self.assertEqual(
98110 "Server notices can only be sent to local users", channel.json_body["error"]
99111 )
100112
101113 @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
102 def test_invalid_parameter(self):
114 def test_invalid_parameter(self) -> None:
103115 """If parameters are invalid, an error is returned."""
104116
105117 # no content, no user
109121 access_token=self.admin_user_tok,
110122 )
111123
112 self.assertEqual(400, channel.code, msg=channel.json_body)
124 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
113125 self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
114126
115127 # no content
120132 content={"user_id": self.other_user},
121133 )
122134
123 self.assertEqual(400, channel.code, msg=channel.json_body)
135 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
124136 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
125137
126138 # no body
131143 content={"user_id": self.other_user, "content": ""},
132144 )
133145
134 self.assertEqual(400, channel.code, msg=channel.json_body)
146 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
135147 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
136148 self.assertEqual("'body' not in content", channel.json_body["error"])
137149
143155 content={"user_id": self.other_user, "content": {"body": ""}},
144156 )
145157
146 self.assertEqual(400, channel.code, msg=channel.json_body)
158 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
147159 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
148160 self.assertEqual("'msgtype' not in content", channel.json_body["error"])
149161
150 def test_server_notice_disabled(self):
162 def test_server_notice_disabled(self) -> None:
151163 """Tests that server returns error if server notice is disabled"""
152164 channel = self.make_request(
153165 "POST",
159171 },
160172 )
161173
162 self.assertEqual(400, channel.code, msg=channel.json_body)
174 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
163175 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
164176 self.assertEqual(
165177 "Server notices are not enabled on this server", channel.json_body["error"]
166178 )
167179
168180 @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
169 def test_send_server_notice(self):
181 def test_send_server_notice(self) -> None:
170182 """
171183 Tests that sending two server notices succeeds,
172184 that the server reuses the same room and does not send messages twice.
184196 "content": {"msgtype": "m.text", "body": "test msg one"},
185197 },
186198 )
187 self.assertEqual(200, channel.code, msg=channel.json_body)
199 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
188200
189201 # user has one invite
190202 invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
215227 "content": {"msgtype": "m.text", "body": "test msg two"},
216228 },
217229 )
218 self.assertEqual(200, channel.code, msg=channel.json_body)
230 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
219231
220232 # user has no new invites or memberships
221233 self._check_invite_and_join_status(self.other_user, 0, 1)
230242 self.assertEqual(messages[1]["sender"], "@notices:test")
231243
232244 @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
233 def test_send_server_notice_leave_room(self):
245 def test_send_server_notice_leave_room(self) -> None:
234246 """
235247 Tests that sending a server notice succeeds.
236248 The user leaves the room and the second message appears
249261 "content": {"msgtype": "m.text", "body": "test msg one"},
250262 },
251263 )
252 self.assertEqual(200, channel.code, msg=channel.json_body)
264 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
253265
254266 # user has one invite
255267 invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
292304 "content": {"msgtype": "m.text", "body": "test msg two"},
293305 },
294306 )
295 self.assertEqual(200, channel.code, msg=channel.json_body)
307 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
296308
297309 # user has one invite
298310 invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
314326 self.assertNotEqual(first_room_id, second_room_id)
315327
316328 @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
317 def test_send_server_notice_delete_room(self):
329 def test_send_server_notice_delete_room(self) -> None:
318330 """
319331 Tests that the user gets a server notice in a new room
320332 after the first server notice room was deleted.
332344 "content": {"msgtype": "m.text", "body": "test msg one"},
333345 },
334346 )
335 self.assertEqual(200, channel.code, msg=channel.json_body)
347 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
336348
337349 # user has one invite
338350 invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
381393 "content": {"msgtype": "m.text", "body": "test msg two"},
382394 },
383395 )
384 self.assertEqual(200, channel.code, msg=channel.json_body)
396 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
385397
386398 # user has one invite
387399 invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
404416
405417 def _check_invite_and_join_status(
406418 self, user_id: str, expected_invites: int, expected_memberships: int
407 ) -> RoomsForUser:
419 ) -> List[RoomsForUser]:
408420 """Check invite and room membership status of a user.
409421
410422 Args
439451 channel = self.make_request(
440452 "GET", "/_matrix/client/r0/sync", access_token=token
441453 )
442 self.assertEqual(channel.code, 200)
454 self.assertEqual(channel.code, HTTPStatus.OK)
443455
444456 # Get the messages
445457 room = channel.json_body["rooms"]["join"][room_id]
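This file also corrects a return annotation (`-> RoomsForUser` becoming `-> List[RoomsForUser]`), since the helper actually returns a list of membership rows. A minimal illustration of the corrected shape, using a hypothetical stand-in row type (`RoomsForUser` itself lives in `synapse.storage.roommember`):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MembershipRow:  # hypothetical stand-in for RoomsForUser
    room_id: str
    membership: str


def invited_rooms(rows: List[MembershipRow]) -> List[MembershipRow]:
    """Filter to invite-membership rows. The List[...] annotation
    matches the list the function actually returns."""
    return [row for row in rows if row.membership == "invite"]
```

Annotating the element type rather than the container's element class keeps type checkers like mypy from flagging every indexing or iteration over the result.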
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 import json
16 from typing import Any, Dict, List, Optional
14 from http import HTTPStatus
15 from typing import List, Optional
16
17 from twisted.test.proto_helpers import MemoryReactor
1718
1819 import synapse.rest.admin
1920 from synapse.api.errors import Codes
2021 from synapse.rest.client import login
22 from synapse.server import HomeServer
23 from synapse.types import JsonDict
24 from synapse.util import Clock
2125
2226 from tests import unittest
2327 from tests.test_utils import SMALL_PNG
2933 login.register_servlets,
3034 ]
3135
32 def prepare(self, reactor, clock, hs):
36 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
3337 self.media_repo = hs.get_media_repository_resource()
3438
3539 self.admin_user = self.register_user("admin", "pass", admin=True)
4044
4145 self.url = "/_synapse/admin/v1/statistics/users/media"
4246
43 def test_no_auth(self):
47 def test_no_auth(self) -> None:
4448 """
4549 Try to list users without authentication.
4650 """
4751 channel = self.make_request("GET", self.url, b"{}")
4852
49 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
53 self.assertEqual(
54 HTTPStatus.UNAUTHORIZED,
55 channel.code,
56 msg=channel.json_body,
57 )
5058 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
5159
52 def test_requester_is_no_admin(self):
53 """
54 If the user is not a server admin, an error 403 is returned.
60 def test_requester_is_no_admin(self) -> None:
61 """
62 If the user is not a server admin, an error with status HTTPStatus.FORBIDDEN is returned.
5563 """
5664 channel = self.make_request(
5765 "GET",
5866 self.url,
59 json.dumps({}),
67 {},
6068 access_token=self.other_user_tok,
6169 )
6270
63 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
71 self.assertEqual(
72 HTTPStatus.FORBIDDEN,
73 channel.code,
74 msg=channel.json_body,
75 )
6476 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
6577
66 def test_invalid_parameter(self):
78 def test_invalid_parameter(self) -> None:
6779 """
6880 If parameters are invalid, an error is returned.
6981 """
7486 access_token=self.admin_user_tok,
7587 )
7688
77 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
89 self.assertEqual(
90 HTTPStatus.BAD_REQUEST,
91 channel.code,
92 msg=channel.json_body,
93 )
7894 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
7995
8096 # negative from
84100 access_token=self.admin_user_tok,
85101 )
86102
87 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
103 self.assertEqual(
104 HTTPStatus.BAD_REQUEST,
105 channel.code,
106 msg=channel.json_body,
107 )
88108 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
89109
90110 # negative limit
94114 access_token=self.admin_user_tok,
95115 )
96116
97 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
117 self.assertEqual(
118 HTTPStatus.BAD_REQUEST,
119 channel.code,
120 msg=channel.json_body,
121 )
98122 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
99123
100124 # negative from_ts
104128 access_token=self.admin_user_tok,
105129 )
106130
107 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
131 self.assertEqual(
132 HTTPStatus.BAD_REQUEST,
133 channel.code,
134 msg=channel.json_body,
135 )
108136 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
109137
110138 # negative until_ts
114142 access_token=self.admin_user_tok,
115143 )
116144
117 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
145 self.assertEqual(
146 HTTPStatus.BAD_REQUEST,
147 channel.code,
148 msg=channel.json_body,
149 )
118150 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
119151
120152 # until_ts smaller than from_ts
124156 access_token=self.admin_user_tok,
125157 )
126158
127 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
159 self.assertEqual(
160 HTTPStatus.BAD_REQUEST,
161 channel.code,
162 msg=channel.json_body,
163 )
128164 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
129165
130166 # empty search term
134170 access_token=self.admin_user_tok,
135171 )
136172
137 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
173 self.assertEqual(
174 HTTPStatus.BAD_REQUEST,
175 channel.code,
176 msg=channel.json_body,
177 )
138178 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
139179
140180 # invalid search order
144184 access_token=self.admin_user_tok,
145185 )
146186
147 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
148 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
149
150 def test_limit(self):
187 self.assertEqual(
188 HTTPStatus.BAD_REQUEST,
189 channel.code,
190 msg=channel.json_body,
191 )
192 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
193
194 def test_limit(self) -> None:
151195 """
152196 Testing list of media with limit
153197 """
159203 access_token=self.admin_user_tok,
160204 )
161205
162 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
206 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
163207 self.assertEqual(channel.json_body["total"], 10)
164208 self.assertEqual(len(channel.json_body["users"]), 5)
165209 self.assertEqual(channel.json_body["next_token"], 5)
166210 self._check_fields(channel.json_body["users"])
167211
168 def test_from(self):
212 def test_from(self) -> None:
169213 """
170214 Testing list of media with a defined starting point (from)
171215 """
177221 access_token=self.admin_user_tok,
178222 )
179223
180 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
224 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
181225 self.assertEqual(channel.json_body["total"], 20)
182226 self.assertEqual(len(channel.json_body["users"]), 15)
183227 self.assertNotIn("next_token", channel.json_body)
184228 self._check_fields(channel.json_body["users"])
185229
186 def test_limit_and_from(self):
230 def test_limit_and_from(self) -> None:
187231 """
188232 Testing list of media with a defined starting point and limit
189233 """
195239 access_token=self.admin_user_tok,
196240 )
197241
198 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
242 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
199243 self.assertEqual(channel.json_body["total"], 20)
200244 self.assertEqual(channel.json_body["next_token"], 15)
201245 self.assertEqual(len(channel.json_body["users"]), 10)
202246 self._check_fields(channel.json_body["users"])
203247
204 def test_next_token(self):
248 def test_next_token(self) -> None:
205249 """
206250 Testing that `next_token` appears at the right place
207251 """
217261 access_token=self.admin_user_tok,
218262 )
219263
220 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
264 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
221265 self.assertEqual(channel.json_body["total"], number_users)
222266 self.assertEqual(len(channel.json_body["users"]), number_users)
223267 self.assertNotIn("next_token", channel.json_body)
230274 access_token=self.admin_user_tok,
231275 )
232276
233 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
277 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
234278 self.assertEqual(channel.json_body["total"], number_users)
235279 self.assertEqual(len(channel.json_body["users"]), number_users)
236280 self.assertNotIn("next_token", channel.json_body)
243287 access_token=self.admin_user_tok,
244288 )
245289
246 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
290 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
247291 self.assertEqual(channel.json_body["total"], number_users)
248292 self.assertEqual(len(channel.json_body["users"]), 19)
249293 self.assertEqual(channel.json_body["next_token"], 19)
256300 access_token=self.admin_user_tok,
257301 )
258302
259 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
303 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
260304 self.assertEqual(channel.json_body["total"], number_users)
261305 self.assertEqual(len(channel.json_body["users"]), 1)
262306 self.assertNotIn("next_token", channel.json_body)
263307
264 def test_no_media(self):
308 def test_no_media(self) -> None:
265309 """
266310 Tests that a normal lookup for statistics succeeds
267311 if users have no media created
273317 access_token=self.admin_user_tok,
274318 )
275319
276 self.assertEqual(200, channel.code, msg=channel.json_body)
320 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
277321 self.assertEqual(0, channel.json_body["total"])
278322 self.assertEqual(0, len(channel.json_body["users"]))
279323
280 def test_order_by(self):
324 def test_order_by(self) -> None:
281325 """
282326 Testing order list with parameter `order_by`
283327 """
355399 "b",
356400 )
357401
358 def test_from_until_ts(self):
402 def test_from_until_ts(self) -> None:
359403 """
360404 Testing filter by time with parameters `from_ts` and `until_ts`
361405 """
370414 self.url,
371415 access_token=self.admin_user_tok,
372416 )
373 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
417 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
374418 self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
375419
376420 # filter media starting at `ts1` after creating first media
380424 self.url + "?from_ts=%s" % (ts1,),
381425 access_token=self.admin_user_tok,
382426 )
383 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
427 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
384428 self.assertEqual(channel.json_body["total"], 0)
385429
386430 self._create_media(self.other_user_tok, 3)
395439 self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
396440 access_token=self.admin_user_tok,
397441 )
398 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
442 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
399443 self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
400444
401445 # filter media until `ts2` and earlier
404448 self.url + "?until_ts=%s" % (ts2,),
405449 access_token=self.admin_user_tok,
406450 )
407 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
451 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
408452 self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
409453
410 def test_search_term(self):
454 def test_search_term(self) -> None:
411455 self._create_users_with_media(20, 1)
412456
413457 # check without filter get all users
416460 self.url,
417461 access_token=self.admin_user_tok,
418462 )
419 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
463 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
420464 self.assertEqual(channel.json_body["total"], 20)
421465
422466 # filter user 1 and 10-19 by `user_id`
425469 self.url + "?search_term=foo_user_1",
426470 access_token=self.admin_user_tok,
427471 )
428 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
472 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
429473 self.assertEqual(channel.json_body["total"], 11)
430474
431475 # filter on this user in `displayname`
434478 self.url + "?search_term=bar_user_10",
435479 access_token=self.admin_user_tok,
436480 )
437 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
481 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
438482 self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
439483 self.assertEqual(channel.json_body["total"], 1)
440484
444488 self.url + "?search_term=foobar",
445489 access_token=self.admin_user_tok,
446490 )
447 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
491 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
448492 self.assertEqual(channel.json_body["total"], 0)
449493
450 def _create_users_with_media(self, number_users: int, media_per_user: int):
494 def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
451495 """
452496 Create a number of users with a number of media
453497 Args:
459503 user_tok = self.login("foo_user_%s" % i, "pass")
460504 self._create_media(user_tok, media_per_user)
461505
462 def _create_media(self, user_token: str, number_media: int):
506 def _create_media(self, user_token: str, number_media: int) -> None:
463507 """
464508 Create a number of media for a specific user
465509 Args:
470514 for _ in range(number_media):
471515 # Upload some media into the room
472516 self.helper.upload_media(
473 upload_resource, SMALL_PNG, tok=user_token, expect_code=200
517 upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
474518 )
475519
476 def _check_fields(self, content: List[Dict[str, Any]]):
520 def _check_fields(self, content: List[JsonDict]) -> None:
477521 """Checks that all attributes are present in content
478522 Args:
479523 content: List that is checked for content
486530
487531 def _order_test(
488532 self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None
489 ):
533 ) -> None:
490534 """Request the list of users in a certain order. Assert that order is what
491535 we expect
492536 Args:
504548 url.encode("ascii"),
505549 access_token=self.admin_user_tok,
506550 )
507 self.assertEqual(200, channel.code, msg=channel.json_body)
551 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
508552 self.assertEqual(channel.json_body["total"], len(expected_user_list))
509553
510554 returned_order = [row["user_id"] for row in channel.json_body["users"]]
1616 import os
1717 import urllib.parse
1818 from binascii import unhexlify
19 from http import HTTPStatus
1920 from typing import List, Optional
2021 from unittest.mock import Mock, patch
2122
7374
7475 channel = self.make_request("POST", self.url, b"{}")
7576
76 self.assertEqual(400, channel.code, msg=channel.json_body)
77 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
7778 self.assertEqual(
7879 "Shared secret registration is not enabled", channel.json_body["error"]
7980 )
105106 body = {"nonce": nonce}
106107 channel = self.make_request("POST", self.url, body)
107108
108 self.assertEqual(400, channel.code, msg=channel.json_body)
109 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
109110 self.assertEqual("username must be specified", channel.json_body["error"])
110111
111112 # 61 seconds
113114
114115 channel = self.make_request("POST", self.url, body)
115116
116 self.assertEqual(400, channel.code, msg=channel.json_body)
117 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
117118 self.assertEqual("unrecognised nonce", channel.json_body["error"])
118119
119120 def test_register_incorrect_nonce(self):
125126
126127 want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
127128 want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
128 want_mac = want_mac.hexdigest()
129 want_mac_str = want_mac.hexdigest()
129130
130131 body = {
131132 "nonce": nonce,
132133 "username": "bob",
133134 "password": "abc123",
134135 "admin": True,
135 "mac": want_mac,
136 "mac": want_mac_str,
136137 }
137138 channel = self.make_request("POST", self.url, body)
138139
139 self.assertEqual(403, channel.code, msg=channel.json_body)
140 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
140141 self.assertEqual("HMAC incorrect", channel.json_body["error"])
141142
142143 def test_register_correct_nonce(self):
151152 want_mac.update(
152153 nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
153154 )
154 want_mac = want_mac.hexdigest()
155 want_mac_str = want_mac.hexdigest()
155156
156157 body = {
157158 "nonce": nonce,
159160 "password": "abc123",
160161 "admin": True,
161162 "user_type": UserTypes.SUPPORT,
162 "mac": want_mac,
163 "mac": want_mac_str,
163164 }
164165 channel = self.make_request("POST", self.url, body)
165166
166 self.assertEqual(200, channel.code, msg=channel.json_body)
167 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
167168 self.assertEqual("@bob:test", channel.json_body["user_id"])
168169
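The `want_mac` construction these tests repeat can be sketched as a standalone helper. The NUL-separated field order (nonce, username, password, admin flag, optional user type) mirrors the `want_mac.update(...)` calls above; `register_mac` is a name introduced here for illustration, and the `b"shared"` secret is the test fixture's stand-in for the server's registration shared secret:

```python
import hashlib
import hmac
from typing import Optional

def register_mac(
    shared_secret: bytes,
    nonce: str,
    username: str,
    password: str,
    admin: bool,
    user_type: Optional[str] = None,
) -> str:
    """Compute the HMAC-SHA1 that the shared-secret register endpoint
    checks, with fields NUL-separated in the order the tests use."""
    mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
    mac.update(nonce.encode("ascii"))
    mac.update(b"\x00")
    mac.update(username.encode("utf8"))
    mac.update(b"\x00")
    mac.update(password.encode("utf8"))
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    if user_type:
        mac.update(b"\x00")
        mac.update(user_type.encode("utf8"))
    return mac.hexdigest()
```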
169170 def test_nonce_reuse(self):
175176
176177 want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
177178 want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
178 want_mac = want_mac.hexdigest()
179 want_mac_str = want_mac.hexdigest()
179180
180181 body = {
181182 "nonce": nonce,
182183 "username": "bob",
183184 "password": "abc123",
184185 "admin": True,
185 "mac": want_mac,
186 "mac": want_mac_str,
186187 }
187188 channel = self.make_request("POST", self.url, body)
188189
189 self.assertEqual(200, channel.code, msg=channel.json_body)
190 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
190191 self.assertEqual("@bob:test", channel.json_body["user_id"])
191192
192193 # Now, try and reuse it
193194 channel = self.make_request("POST", self.url, body)
194195
195 self.assertEqual(400, channel.code, msg=channel.json_body)
196 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
196197 self.assertEqual("unrecognised nonce", channel.json_body["error"])
197198
198199 def test_missing_parts(self):

213214 # A body must be present (even an empty one)
214215 channel = self.make_request("POST", self.url, {})
215216
216 self.assertEqual(400, channel.code, msg=channel.json_body)
217 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
217218 self.assertEqual("nonce must be specified", channel.json_body["error"])
218219
219220 #
223224 # Must be present
224225 channel = self.make_request("POST", self.url, {"nonce": nonce()})
225226
226 self.assertEqual(400, channel.code, msg=channel.json_body)
227 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
227228 self.assertEqual("username must be specified", channel.json_body["error"])
228229
229230 # Must be a string
230231 body = {"nonce": nonce(), "username": 1234}
231232 channel = self.make_request("POST", self.url, body)
232233
233 self.assertEqual(400, channel.code, msg=channel.json_body)
234 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
234235 self.assertEqual("Invalid username", channel.json_body["error"])
235236
236237 # Must not have null bytes
237238 body = {"nonce": nonce(), "username": "abcd\u0000"}
238239 channel = self.make_request("POST", self.url, body)
239240
240 self.assertEqual(400, channel.code, msg=channel.json_body)
241 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
241242 self.assertEqual("Invalid username", channel.json_body["error"])
242243
243244 # Super long
244245 body = {"nonce": nonce(), "username": "a" * 1000}
245246 channel = self.make_request("POST", self.url, body)
246247
247 self.assertEqual(400, channel.code, msg=channel.json_body)
248 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
248249 self.assertEqual("Invalid username", channel.json_body["error"])
249250
250251 #
255256 body = {"nonce": nonce(), "username": "a"}
256257 channel = self.make_request("POST", self.url, body)
257258
258 self.assertEqual(400, channel.code, msg=channel.json_body)
259 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
259260 self.assertEqual("password must be specified", channel.json_body["error"])
260261
261262 # Must be a string
262263 body = {"nonce": nonce(), "username": "a", "password": 1234}
263264 channel = self.make_request("POST", self.url, body)
264265
265 self.assertEqual(400, channel.code, msg=channel.json_body)
266 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
266267 self.assertEqual("Invalid password", channel.json_body["error"])
267268
268269 # Must not have null bytes
269270 body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
270271 channel = self.make_request("POST", self.url, body)
271272
272 self.assertEqual(400, channel.code, msg=channel.json_body)
273 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
273274 self.assertEqual("Invalid password", channel.json_body["error"])
274275
275276 # Super long
276277 body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
277278 channel = self.make_request("POST", self.url, body)
278279
279 self.assertEqual(400, channel.code, msg=channel.json_body)
280 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
280281 self.assertEqual("Invalid password", channel.json_body["error"])
281282
282283 #
292293 }
293294 channel = self.make_request("POST", self.url, body)
294295
295 self.assertEqual(400, channel.code, msg=channel.json_body)
296 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
296297 self.assertEqual("Invalid user type", channel.json_body["error"])
297298
298299 def test_displayname(self):
306307
307308 want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
308309 want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
309 want_mac = want_mac.hexdigest()
310 want_mac_str = want_mac.hexdigest()
310311
311312 body = {
312313 "nonce": nonce,
313314 "username": "bob1",
314315 "password": "abc123",
315 "mac": want_mac,
316 "mac": want_mac_str,
316317 }
317318
318319 channel = self.make_request("POST", self.url, body)
319320
320 self.assertEqual(200, channel.code, msg=channel.json_body)
321 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
321322 self.assertEqual("@bob1:test", channel.json_body["user_id"])
322323
323324 channel = self.make_request("GET", "/profile/@bob1:test/displayname")
324 self.assertEqual(200, channel.code, msg=channel.json_body)
325 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
325326 self.assertEqual("bob1", channel.json_body["displayname"])
326327
327328 # displayname is None
330331
331332 want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
332333 want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
333 want_mac = want_mac.hexdigest()
334 want_mac_str = want_mac.hexdigest()
334335
335336 body = {
336337 "nonce": nonce,
337338 "username": "bob2",
338339 "displayname": None,
339340 "password": "abc123",
340 "mac": want_mac,
341 "mac": want_mac_str,
341342 }
342343 channel = self.make_request("POST", self.url, body)
343344
344 self.assertEqual(200, channel.code, msg=channel.json_body)
345 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
345346 self.assertEqual("@bob2:test", channel.json_body["user_id"])
346347
347348 channel = self.make_request("GET", "/profile/@bob2:test/displayname")
348 self.assertEqual(200, channel.code, msg=channel.json_body)
349 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
349350 self.assertEqual("bob2", channel.json_body["displayname"])
350351
351352 # displayname is empty
354355
355356 want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
356357 want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
357 want_mac = want_mac.hexdigest()
358 want_mac_str = want_mac.hexdigest()
358359
359360 body = {
360361 "nonce": nonce,
361362 "username": "bob3",
362363 "displayname": "",
363364 "password": "abc123",
364 "mac": want_mac,
365 "mac": want_mac_str,
365366 }
366367 channel = self.make_request("POST", self.url, body)
367368
368 self.assertEqual(200, channel.code, msg=channel.json_body)
369 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
369370 self.assertEqual("@bob3:test", channel.json_body["user_id"])
370371
371372 channel = self.make_request("GET", "/profile/@bob3:test/displayname")
372 self.assertEqual(404, channel.code, msg=channel.json_body)
373 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
373374
374375 # set displayname
375376 channel = self.make_request("GET", self.url)
377378
378379 want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
379380 want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
380 want_mac = want_mac.hexdigest()
381 want_mac_str = want_mac.hexdigest()
381382
382383 body = {
383384 "nonce": nonce,
384385 "username": "bob4",
385386 "displayname": "Bob's Name",
386387 "password": "abc123",
387 "mac": want_mac,
388 "mac": want_mac_str,
388389 }
389390 channel = self.make_request("POST", self.url, body)
390391
391 self.assertEqual(200, channel.code, msg=channel.json_body)
392 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
392393 self.assertEqual("@bob4:test", channel.json_body["user_id"])
393394
394395 channel = self.make_request("GET", "/profile/@bob4:test/displayname")
395 self.assertEqual(200, channel.code, msg=channel.json_body)
396 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
396397 self.assertEqual("Bob's Name", channel.json_body["displayname"])
397398
398399 @override_config(
424425 want_mac.update(
425426 nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
426427 )
427 want_mac = want_mac.hexdigest()
428 want_mac_str = want_mac.hexdigest()
428429
429430 body = {
430431 "nonce": nonce,
432433 "password": "abc123",
433434 "admin": True,
434435 "user_type": UserTypes.SUPPORT,
435 "mac": want_mac,
436 "mac": want_mac_str,
436437 }
437438 channel = self.make_request("POST", self.url, body)
438439
439 self.assertEqual(200, channel.code, msg=channel.json_body)
440 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
440441 self.assertEqual("@bob:test", channel.json_body["user_id"])
441442
442443
460461 """
461462 channel = self.make_request("GET", self.url, b"{}")
462463
463 self.assertEqual(401, channel.code, msg=channel.json_body)
464 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
464465 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
465466
466467 def test_requester_is_no_admin(self):
472473
473474 channel = self.make_request("GET", self.url, access_token=other_user_token)
474475
475 self.assertEqual(403, channel.code, msg=channel.json_body)
476 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
476477 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
477478
478479 def test_all_users(self):
488489 access_token=self.admin_user_tok,
489490 )
490491
491 self.assertEqual(200, channel.code, msg=channel.json_body)
492 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
492493 self.assertEqual(3, len(channel.json_body["users"]))
493494 self.assertEqual(3, channel.json_body["total"])
494495
502503 expected_user_id: Optional[str],
503504 search_term: str,
504505 search_field: Optional[str] = "name",
505 expected_http_code: Optional[int] = 200,
506 expected_http_code: Optional[int] = HTTPStatus.OK,
506507 ):
507508 """Search for a user and check that the returned user's id is a match
508509
524525 )
525526 self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
526527
527 if expected_http_code != 200:
528 if expected_http_code != HTTPStatus.OK:
528529 return
529530
530531 # Check that users were returned
585586 access_token=self.admin_user_tok,
586587 )
587588
588 self.assertEqual(400, channel.code, msg=channel.json_body)
589 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
589590 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
590591
591592 # negative from
595596 access_token=self.admin_user_tok,
596597 )
597598
598 self.assertEqual(400, channel.code, msg=channel.json_body)
599 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
599600 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
600601
601602 # invalid guests
605606 access_token=self.admin_user_tok,
606607 )
607608
608 self.assertEqual(400, channel.code, msg=channel.json_body)
609 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
609610 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
610611
611612 # invalid deactivated
615616 access_token=self.admin_user_tok,
616617 )
617618
618 self.assertEqual(400, channel.code, msg=channel.json_body)
619 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
619620 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
620621
621622 # unknown order_by
625626 access_token=self.admin_user_tok,
626627 )
627628
628 self.assertEqual(400, channel.code, msg=channel.json_body)
629 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
629630 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
630631
631632 # invalid search order
635636 access_token=self.admin_user_tok,
636637 )
637638
638 self.assertEqual(400, channel.code, msg=channel.json_body)
639 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
639640 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
640641
641642 def test_limit(self):
653654 access_token=self.admin_user_tok,
654655 )
655656
656 self.assertEqual(200, channel.code, msg=channel.json_body)
657 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
657658 self.assertEqual(channel.json_body["total"], number_users)
658659 self.assertEqual(len(channel.json_body["users"]), 5)
659660 self.assertEqual(channel.json_body["next_token"], "5")
674675 access_token=self.admin_user_tok,
675676 )
676677
677 self.assertEqual(200, channel.code, msg=channel.json_body)
678 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
678679 self.assertEqual(channel.json_body["total"], number_users)
679680 self.assertEqual(len(channel.json_body["users"]), 15)
680681 self.assertNotIn("next_token", channel.json_body)
695696 access_token=self.admin_user_tok,
696697 )
697698
698 self.assertEqual(200, channel.code, msg=channel.json_body)
699 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
699700 self.assertEqual(channel.json_body["total"], number_users)
700701 self.assertEqual(channel.json_body["next_token"], "15")
701702 self.assertEqual(len(channel.json_body["users"]), 10)
718719 access_token=self.admin_user_tok,
719720 )
720721
721 self.assertEqual(200, channel.code, msg=channel.json_body)
722 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
722723 self.assertEqual(channel.json_body["total"], number_users)
723724 self.assertEqual(len(channel.json_body["users"]), number_users)
724725 self.assertNotIn("next_token", channel.json_body)
731732 access_token=self.admin_user_tok,
732733 )
733734
734 self.assertEqual(200, channel.code, msg=channel.json_body)
735 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
735736 self.assertEqual(channel.json_body["total"], number_users)
736737 self.assertEqual(len(channel.json_body["users"]), number_users)
737738 self.assertNotIn("next_token", channel.json_body)
744745 access_token=self.admin_user_tok,
745746 )
746747
747 self.assertEqual(200, channel.code, msg=channel.json_body)
748 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
748749 self.assertEqual(channel.json_body["total"], number_users)
749750 self.assertEqual(len(channel.json_body["users"]), 19)
750751 self.assertEqual(channel.json_body["next_token"], "19")
758759 access_token=self.admin_user_tok,
759760 )
760761
761 self.assertEqual(200, channel.code, msg=channel.json_body)
762 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
762763 self.assertEqual(channel.json_body["total"], number_users)
763764 self.assertEqual(len(channel.json_body["users"]), 1)
764765 self.assertNotIn("next_token", channel.json_body)
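The pagination contract these `from`/`limit` tests exercise (a `users` page, a `total` count, and a `next_token` that is present only while more pages remain) can be consumed with a loop like the one below. `fetch_page` is a hypothetical callable standing in for an authenticated GET against the users list endpoint:

```python
from typing import Callable, Dict, List

def fetch_all_users(
    fetch_page: Callable[[int, int], Dict], limit: int = 100
) -> List[Dict]:
    """Drain a paginated listing: keep requesting pages until the
    response omits ``next_token``, which marks the final page."""
    users: List[Dict] = []
    start = 0
    while True:
        page = fetch_page(start, limit)
        users.extend(page["users"])
        token = page.get("next_token")
        if token is None:
            break
        # next_token comes back as a string (e.g. "5", "15" above)
        start = int(token)
    return users
```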
861862 url,
862863 access_token=self.admin_user_tok,
863864 )
864 self.assertEqual(200, channel.code, msg=channel.json_body)
865 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
865866 self.assertEqual(channel.json_body["total"], len(expected_user_list))
866867
867868 returned_order = [row["name"] for row in channel.json_body["users"]]
868869 self.assertEqual(expected_user_list, returned_order)
869870 self._check_fields(channel.json_body["users"])
870871
871 def _check_fields(self, content: JsonDict):
872 def _check_fields(self, content: List[JsonDict]):
872873 """Checks that the expected user attributes are present in content
873874 Args:
874875 content: List that is checked for content
935936 """
936937 channel = self.make_request("POST", self.url, b"{}")
937938
938 self.assertEqual(401, channel.code, msg=channel.json_body)
939 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
939940 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
940941
941942 def test_requester_is_not_admin(self):
946947
947948 channel = self.make_request("POST", url, access_token=self.other_user_token)
948949
949 self.assertEqual(403, channel.code, msg=channel.json_body)
950 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
950951 self.assertEqual("You are not a server admin", channel.json_body["error"])
951952
952953 channel = self.make_request(
956957 content=b"{}",
957958 )
958959
959 self.assertEqual(403, channel.code, msg=channel.json_body)
960 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
960961 self.assertEqual("You are not a server admin", channel.json_body["error"])
961962
962963 def test_user_does_not_exist(self):
963964 """
964 Tests that deactivation for a user that does not exist returns a 404
965 Tests that deactivation of a user that does not exist returns HTTPStatus.NOT_FOUND
965966 """
966967
967968 channel = self.make_request(
970971 access_token=self.admin_user_tok,
971972 )
972973
973 self.assertEqual(404, channel.code, msg=channel.json_body)
974 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
974975 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
975976
976977 def test_erase_is_not_bool(self):
985986 access_token=self.admin_user_tok,
986987 )
987988
988 self.assertEqual(400, channel.code, msg=channel.json_body)
989 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
989990 self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
990991
991992 def test_user_is_not_local(self):
992993 """
993 Tests that deactivation for a user that is not a local returns a 400
994 Tests that deactivation of a non-local user returns HTTPStatus.BAD_REQUEST
994995 """
995996 url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
996997
997998 channel = self.make_request("POST", url, access_token=self.admin_user_tok)
998999
999 self.assertEqual(400, channel.code, msg=channel.json_body)
1000 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
10001001 self.assertEqual("Can only deactivate local users", channel.json_body["error"])
10011002
10021003 def test_deactivate_user_erase_true(self):
10111012 access_token=self.admin_user_tok,
10121013 )
10131014
1014 self.assertEqual(200, channel.code, msg=channel.json_body)
1015 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10151016 self.assertEqual("@user:test", channel.json_body["name"])
10161017 self.assertEqual(False, channel.json_body["deactivated"])
10171018 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
10261027 content={"erase": True},
10271028 )
10281029
1029 self.assertEqual(200, channel.code, msg=channel.json_body)
1030 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10301031
10311032 # Get user
10321033 channel = self.make_request(
10351036 access_token=self.admin_user_tok,
10361037 )
10371038
1038 self.assertEqual(200, channel.code, msg=channel.json_body)
1039 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10391040 self.assertEqual("@user:test", channel.json_body["name"])
10401041 self.assertEqual(True, channel.json_body["deactivated"])
10411042 self.assertEqual(0, len(channel.json_body["threepids"]))
10561057 access_token=self.admin_user_tok,
10571058 )
10581059
1059 self.assertEqual(200, channel.code, msg=channel.json_body)
1060 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10601061 self.assertEqual("@user:test", channel.json_body["name"])
10611062 self.assertEqual(False, channel.json_body["deactivated"])
10621063 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
10711072 content={"erase": False},
10721073 )
10731074
1074 self.assertEqual(200, channel.code, msg=channel.json_body)
1075 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10751076
10761077 # Get user
10771078 channel = self.make_request(
10801081 access_token=self.admin_user_tok,
10811082 )
10821083
1083 self.assertEqual(200, channel.code, msg=channel.json_body)
1084 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
10841085 self.assertEqual("@user:test", channel.json_body["name"])
10851086 self.assertEqual(True, channel.json_body["deactivated"])
10861087 self.assertEqual(0, len(channel.json_body["threepids"]))
11101111 access_token=self.admin_user_tok,
11111112 )
11121113
1113 self.assertEqual(200, channel.code, msg=channel.json_body)
1114 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
11141115 self.assertEqual("@user:test", channel.json_body["name"])
11151116 self.assertEqual(False, channel.json_body["deactivated"])
11161117 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
11251126 content={"erase": True},
11261127 )
11271128
1128 self.assertEqual(200, channel.code, msg=channel.json_body)
1129 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
11291130
11301131 # Get user
11311132 channel = self.make_request(
11341135 access_token=self.admin_user_tok,
11351136 )
11361137
1137 self.assertEqual(200, channel.code, msg=channel.json_body)
1138 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
11381139 self.assertEqual("@user:test", channel.json_body["name"])
11391140 self.assertEqual(True, channel.json_body["deactivated"])
11401141 self.assertEqual(0, len(channel.json_body["threepids"]))
11941195 access_token=self.other_user_token,
11951196 )
11961197
1197 self.assertEqual(403, channel.code, msg=channel.json_body)
1198 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
11981199 self.assertEqual("You are not a server admin", channel.json_body["error"])
11991200
12001201 channel = self.make_request(
12041205 content=b"{}",
12051206 )
12061207
1207 self.assertEqual(403, channel.code, msg=channel.json_body)
1208 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
12081209 self.assertEqual("You are not a server admin", channel.json_body["error"])
12091210
12101211 def test_user_does_not_exist(self):
12111212 """
1212 Tests that a lookup for a user that does not exist returns a 404
1213 Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
12131214 """
12141215
12151216 channel = self.make_request(
12181219 access_token=self.admin_user_tok,
12191220 )
12201221
1221 self.assertEqual(404, channel.code, msg=channel.json_body)
1222 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
12221223 self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
12231224
12241225 def test_invalid_parameter(self):
12331234 access_token=self.admin_user_tok,
12341235 content={"admin": "not_bool"},
12351236 )
1236 self.assertEqual(400, channel.code, msg=channel.json_body)
1237 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12371238 self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
12381239
12391240 # deactivated not bool
12431244 access_token=self.admin_user_tok,
12441245 content={"deactivated": "not_bool"},
12451246 )
1246 self.assertEqual(400, channel.code, msg=channel.json_body)
1247 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12471248 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
12481249
12491250 # password not str
12531254 access_token=self.admin_user_tok,
12541255 content={"password": True},
12551256 )
1256 self.assertEqual(400, channel.code, msg=channel.json_body)
1257 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12571258 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
12581259
12591260 # password not length
12631264 access_token=self.admin_user_tok,
12641265 content={"password": "x" * 513},
12651266 )
1266 self.assertEqual(400, channel.code, msg=channel.json_body)
1267 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12671268 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
12681269
12691270 # user_type not valid
12731274 access_token=self.admin_user_tok,
12741275 content={"user_type": "new type"},
12751276 )
1276 self.assertEqual(400, channel.code, msg=channel.json_body)
1277 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12771278 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
12781279
12791280 # external_ids not valid
12851286 "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
12861287 },
12871288 )
1288 self.assertEqual(400, channel.code, msg=channel.json_body)
1289 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12891290 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
12901291
12911292 channel = self.make_request(
12941295 access_token=self.admin_user_tok,
12951296 content={"external_ids": {"external_id": "id"}},
12961297 )
1297 self.assertEqual(400, channel.code, msg=channel.json_body)
1298 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
12981299 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
12991300
13001301 # threepids not valid
13041305 access_token=self.admin_user_tok,
13051306 content={"threepids": {"medium": "email", "wrong_address": "id"}},
13061307 )
1307 self.assertEqual(400, channel.code, msg=channel.json_body)
1308 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
13081309 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
13091310
13101311 channel = self.make_request(
13131314 access_token=self.admin_user_tok,
13141315 content={"threepids": {"address": "value"}},
13151316 )
1316 self.assertEqual(400, channel.code, msg=channel.json_body)
1317 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
13171318 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
13181319
13191320 def test_get_user(self):
13261327 access_token=self.admin_user_tok,
13271328 )
13281329
1329 self.assertEqual(200, channel.code, msg=channel.json_body)
1330 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
13301331 self.assertEqual("@user:test", channel.json_body["name"])
13311332 self.assertEqual("User", channel.json_body["displayname"])
13321333 self._check_fields(channel.json_body)
13691370 access_token=self.admin_user_tok,
13701371 )
13711372
1372 self.assertEqual(200, channel.code, msg=channel.json_body)
1373 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
13731374 self.assertEqual("@bob:test", channel.json_body["name"])
13741375 self.assertEqual("Bob's name", channel.json_body["displayname"])
13751376 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
14321433 access_token=self.admin_user_tok,
14331434 )
14341435
1435 self.assertEqual(200, channel.code, msg=channel.json_body)
1436 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
14361437 self.assertEqual("@bob:test", channel.json_body["name"])
14371438 self.assertEqual("Bob's name", channel.json_body["displayname"])
14381439 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
14601461 # before limit of monthly active users is reached
14611462 channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
14621463
1463 if channel.code != 200:
1464 if channel.code != HTTPStatus.OK:
14641465 raise HttpResponseException(
1465 channel.code, channel.result["reason"], channel.result["body"]
1466 channel.code, channel.result["reason"], channel.json_body
14661467 )
14671468
14681469 # Set monthly active users to the limit
16241625 content={"password": "hahaha"},
16251626 )
16261627
1627 self.assertEqual(200, channel.code, msg=channel.json_body)
1628 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16281629 self._check_fields(channel.json_body)
16291630
16301631 def test_set_displayname(self):
16401641 content={"displayname": "foobar"},
16411642 )
16421643
1643 self.assertEqual(200, channel.code, msg=channel.json_body)
1644 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16441645 self.assertEqual("@user:test", channel.json_body["name"])
16451646 self.assertEqual("foobar", channel.json_body["displayname"])
16461647
16511652 access_token=self.admin_user_tok,
16521653 )
16531654
1654 self.assertEqual(200, channel.code, msg=channel.json_body)
1655 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16551656 self.assertEqual("@user:test", channel.json_body["name"])
16561657 self.assertEqual("foobar", channel.json_body["displayname"])
16571658
16731674 },
16741675 )
16751676
1676 self.assertEqual(200, channel.code, msg=channel.json_body)
1677 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
16771678 self.assertEqual("@user:test", channel.json_body["name"])
16781679 self.assertEqual(2, len(channel.json_body["threepids"]))
16791680 # the result is not guaranteed to be sorted, so sort it before comparing
16991700 },
17001701 )
17011702
1702 self.assertEqual(200, channel.code, msg=channel.json_body)
1703 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17031704 self.assertEqual("@user:test", channel.json_body["name"])
17041705 self.assertEqual(2, len(channel.json_body["threepids"]))
17051706 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
17151716 access_token=self.admin_user_tok,
17161717 )
17171718
1718 self.assertEqual(200, channel.code, msg=channel.json_body)
1719 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17191720 self.assertEqual("@user:test", channel.json_body["name"])
17201721 self.assertEqual(2, len(channel.json_body["threepids"]))
17211722 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
17311732 access_token=self.admin_user_tok,
17321733 content={"threepids": []},
17331734 )
1734 self.assertEqual(200, channel.code, msg=channel.json_body)
1735 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17351736 self.assertEqual("@user:test", channel.json_body["name"])
17361737 self.assertEqual(0, len(channel.json_body["threepids"]))
17371738 self._check_fields(channel.json_body)
17581759 },
17591760 )
17601761
1761 self.assertEqual(200, channel.code, msg=channel.json_body)
1762 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17621763 self.assertEqual(first_user, channel.json_body["name"])
17631764 self.assertEqual(1, len(channel.json_body["threepids"]))
17641765 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
17771778 },
17781779 )
17791780
1780 self.assertEqual(200, channel.code, msg=channel.json_body)
1781 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
17811782 self.assertEqual("@user:test", channel.json_body["name"])
17821783 self.assertEqual(1, len(channel.json_body["threepids"]))
17831784 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
17991800 )
18001801
18011802 # other user has this two threepids
1802 self.assertEqual(200, channel.code, msg=channel.json_body)
1803 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
18031804 self.assertEqual("@user:test", channel.json_body["name"])
18041805 self.assertEqual(2, len(channel.json_body["threepids"]))
18051806 # result does not always have the same sort order, therefore it becomes sorted
18181819 url_first_user,
18191820 access_token=self.admin_user_tok,
18201821 )
1821 self.assertEqual(200, channel.code, msg=channel.json_body)
1822 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
18221823 self.assertEqual(first_user, channel.json_body["name"])
18231824 self.assertEqual(0, len(channel.json_body["threepids"]))
18241825 self._check_fields(channel.json_body)
18471848 },
18481849 )
18491850
1850 self.assertEqual(200, channel.code, msg=channel.json_body)
1851 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
18511852 self.assertEqual("@user:test", channel.json_body["name"])
18521853 self.assertEqual(2, len(channel.json_body["external_ids"]))
18531854 # result does not always have the same sort order, therefore it becomes sorted
18791880 },
18801881 )
18811882
1882 self.assertEqual(200, channel.code, msg=channel.json_body)
1883 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
18831884 self.assertEqual("@user:test", channel.json_body["name"])
18841885 self.assertEqual(2, len(channel.json_body["external_ids"]))
18851886 self.assertEqual(
18981899 access_token=self.admin_user_tok,
18991900 )
19001901
1901 self.assertEqual(200, channel.code, msg=channel.json_body)
1902 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
19021903 self.assertEqual("@user:test", channel.json_body["name"])
19031904 self.assertEqual(2, len(channel.json_body["external_ids"]))
19041905 self.assertEqual(
19171918 access_token=self.admin_user_tok,
19181919 content={"external_ids": []},
19191920 )
1920 self.assertEqual(200, channel.code, msg=channel.json_body)
1921 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
19211922 self.assertEqual("@user:test", channel.json_body["name"])
19221923 self.assertEqual(0, len(channel.json_body["external_ids"]))
19231924
19461947 },
19471948 )
19481949
1949 self.assertEqual(200, channel.code, msg=channel.json_body)
1950 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
19501951 self.assertEqual(first_user, channel.json_body["name"])
19511952 self.assertEqual(1, len(channel.json_body["external_ids"]))
19521953 self.assertEqual(
19721973 },
19731974 )
19741975
1975 self.assertEqual(200, channel.code, msg=channel.json_body)
1976 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
19761977 self.assertEqual("@user:test", channel.json_body["name"])
19771978 self.assertEqual(1, len(channel.json_body["external_ids"]))
19781979 self.assertEqual(
20042005 )
20052006
20062007 # must fail
2007 self.assertEqual(409, channel.code, msg=channel.json_body)
2008 self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
20082009 self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
20092010 self.assertEqual("External id is already in use.", channel.json_body["error"])
20102011
20152016 access_token=self.admin_user_tok,
20162017 )
20172018
2018 self.assertEqual(200, channel.code, msg=channel.json_body)
2019 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
20192020 self.assertEqual("@user:test", channel.json_body["name"])
20202021 self.assertEqual(1, len(channel.json_body["external_ids"]))
20212022 self.assertEqual(
20332034 access_token=self.admin_user_tok,
20342035 )
20352036
2036 self.assertEqual(200, channel.code, msg=channel.json_body)
2037 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
20372038 self.assertEqual(first_user, channel.json_body["name"])
20382039 self.assertEqual(1, len(channel.json_body["external_ids"]))
20392040 self.assertEqual(
20642065 access_token=self.admin_user_tok,
20652066 )
20662067
2067 self.assertEqual(200, channel.code, msg=channel.json_body)
2068 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
20682069 self.assertEqual("@user:test", channel.json_body["name"])
20692070 self.assertFalse(channel.json_body["deactivated"])
20702071 self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
20792080 content={"deactivated": True},
20802081 )
20812082
2082 self.assertEqual(200, channel.code, msg=channel.json_body)
2083 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
20832084 self.assertEqual("@user:test", channel.json_body["name"])
20842085 self.assertTrue(channel.json_body["deactivated"])
20852086 self.assertIsNone(channel.json_body["password_hash"])
20952096 access_token=self.admin_user_tok,
20962097 )
20972098
2098 self.assertEqual(200, channel.code, msg=channel.json_body)
2099 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
20992100 self.assertEqual("@user:test", channel.json_body["name"])
21002101 self.assertTrue(channel.json_body["deactivated"])
21012102 self.assertIsNone(channel.json_body["password_hash"])
21222123 content={"deactivated": True},
21232124 )
21242125
2125 self.assertEqual(200, channel.code, msg=channel.json_body)
2126 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
21262127 self.assertEqual("@user:test", channel.json_body["name"])
21272128 self.assertTrue(channel.json_body["deactivated"])
21282129
21382139 content={"displayname": "Foobar"},
21392140 )
21402141
2141 self.assertEqual(200, channel.code, msg=channel.json_body)
2142 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
21422143 self.assertEqual("@user:test", channel.json_body["name"])
21432144 self.assertTrue(channel.json_body["deactivated"])
21442145 self.assertEqual("Foobar", channel.json_body["displayname"])
21622163 access_token=self.admin_user_tok,
21632164 content={"deactivated": False},
21642165 )
2165 self.assertEqual(400, channel.code, msg=channel.json_body)
2166 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
21662167
21672168 # Reactivate the user.
21682169 channel = self.make_request(
21712172 access_token=self.admin_user_tok,
21722173 content={"deactivated": False, "password": "foo"},
21732174 )
2174 self.assertEqual(200, channel.code, msg=channel.json_body)
2175 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
21752176 self.assertEqual("@user:test", channel.json_body["name"])
21762177 self.assertFalse(channel.json_body["deactivated"])
21772178 self.assertIsNotNone(channel.json_body["password_hash"])
21932194 access_token=self.admin_user_tok,
21942195 content={"deactivated": False, "password": "foo"},
21952196 )
2196 self.assertEqual(403, channel.code, msg=channel.json_body)
2197 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
21972198 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
21982199
21992200 # Reactivate the user without a password.
22032204 access_token=self.admin_user_tok,
22042205 content={"deactivated": False},
22052206 )
2206 self.assertEqual(200, channel.code, msg=channel.json_body)
2207 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
22072208 self.assertEqual("@user:test", channel.json_body["name"])
22082209 self.assertFalse(channel.json_body["deactivated"])
22092210 self.assertIsNone(channel.json_body["password_hash"])
22252226 access_token=self.admin_user_tok,
22262227 content={"deactivated": False, "password": "foo"},
22272228 )
2228 self.assertEqual(403, channel.code, msg=channel.json_body)
2229 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
22292230 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
22302231
22312232 # Reactivate the user without a password.
22352236 access_token=self.admin_user_tok,
22362237 content={"deactivated": False},
22372238 )
2238 self.assertEqual(200, channel.code, msg=channel.json_body)
2239 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
22392240 self.assertEqual("@user:test", channel.json_body["name"])
22402241 self.assertFalse(channel.json_body["deactivated"])
22412242 self.assertIsNone(channel.json_body["password_hash"])
22542255 content={"admin": True},
22552256 )
22562257
2257 self.assertEqual(200, channel.code, msg=channel.json_body)
2258 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
22582259 self.assertEqual("@user:test", channel.json_body["name"])
22592260 self.assertTrue(channel.json_body["admin"])
22602261
22652266 access_token=self.admin_user_tok,
22662267 )
22672268
2268 self.assertEqual(200, channel.code, msg=channel.json_body)
2269 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
22692270 self.assertEqual("@user:test", channel.json_body["name"])
22702271 self.assertTrue(channel.json_body["admin"])
22712272
22822283 content={"user_type": UserTypes.SUPPORT},
22832284 )
22842285
2285 self.assertEqual(200, channel.code, msg=channel.json_body)
2286 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
22862287 self.assertEqual("@user:test", channel.json_body["name"])
22872288 self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
22882289
22932294 access_token=self.admin_user_tok,
22942295 )
22952296
2296 self.assertEqual(200, channel.code, msg=channel.json_body)
2297 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
22972298 self.assertEqual("@user:test", channel.json_body["name"])
22982299 self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
22992300
23052306 content={"user_type": None},
23062307 )
23072308
2308 self.assertEqual(200, channel.code, msg=channel.json_body)
2309 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
23092310 self.assertEqual("@user:test", channel.json_body["name"])
23102311 self.assertIsNone(channel.json_body["user_type"])
23112312
23162317 access_token=self.admin_user_tok,
23172318 )
23182319
2319 self.assertEqual(200, channel.code, msg=channel.json_body)
2320 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
23202321 self.assertEqual("@user:test", channel.json_body["name"])
23212322 self.assertIsNone(channel.json_body["user_type"])
23222323
23462347 access_token=self.admin_user_tok,
23472348 )
23482349
2349 self.assertEqual(200, channel.code, msg=channel.json_body)
2350 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
23502351 self.assertEqual("@bob:test", channel.json_body["name"])
23512352 self.assertEqual("bob", channel.json_body["displayname"])
23522353 self.assertEqual(0, channel.json_body["deactivated"])
23592360 content={"password": "abc123", "deactivated": "false"},
23602361 )
23612362
2362 self.assertEqual(400, channel.code, msg=channel.json_body)
2363 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
23632364
23642365 # Check user is not deactivated
23652366 channel = self.make_request(
23682369 access_token=self.admin_user_tok,
23692370 )
23702371
2371 self.assertEqual(200, channel.code, msg=channel.json_body)
2372 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
23722373 self.assertEqual("@bob:test", channel.json_body["name"])
23732374 self.assertEqual("bob", channel.json_body["displayname"])
23742375
23932394 access_token=self.admin_user_tok,
23942395 content={"deactivated": True},
23952396 )
2396 self.assertEqual(200, channel.code, msg=channel.json_body)
2397 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
23972398 self.assertTrue(channel.json_body["deactivated"])
23982399 self.assertIsNone(channel.json_body["password_hash"])
23992400 self._is_erased(user_id, False)
24442445 """
24452446 channel = self.make_request("GET", self.url, b"{}")
24462447
2447 self.assertEqual(401, channel.code, msg=channel.json_body)
2448 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
24482449 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
24492450
24502451 def test_requester_is_no_admin(self):
24592460 access_token=other_user_token,
24602461 )
24612462
2462 self.assertEqual(403, channel.code, msg=channel.json_body)
2463 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
24632464 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
24642465
24652466 def test_user_does_not_exist(self):
24732474 access_token=self.admin_user_tok,
24742475 )
24752476
2476 self.assertEqual(200, channel.code, msg=channel.json_body)
2477 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
24772478 self.assertEqual(0, channel.json_body["total"])
24782479 self.assertEqual(0, len(channel.json_body["joined_rooms"]))
24792480
24892490 access_token=self.admin_user_tok,
24902491 )
24912492
2492 self.assertEqual(200, channel.code, msg=channel.json_body)
2493 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
24932494 self.assertEqual(0, channel.json_body["total"])
24942495 self.assertEqual(0, len(channel.json_body["joined_rooms"]))
24952496
25052506 access_token=self.admin_user_tok,
25062507 )
25072508
2508 self.assertEqual(200, channel.code, msg=channel.json_body)
2509 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
25092510 self.assertEqual(0, channel.json_body["total"])
25102511 self.assertEqual(0, len(channel.json_body["joined_rooms"]))
25112512
25262527 access_token=self.admin_user_tok,
25272528 )
25282529
2529 self.assertEqual(200, channel.code, msg=channel.json_body)
2530 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
25302531 self.assertEqual(number_rooms, channel.json_body["total"])
25312532 self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
25322533
25732574 access_token=self.admin_user_tok,
25742575 )
25752576
2576 self.assertEqual(200, channel.code, msg=channel.json_body)
2577 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
25772578 self.assertEqual(1, channel.json_body["total"])
25782579 self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
25792580
26022603 """
26032604 channel = self.make_request("GET", self.url, b"{}")
26042605
2605 self.assertEqual(401, channel.code, msg=channel.json_body)
2606 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
26062607 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
26072608
26082609 def test_requester_is_no_admin(self):
26172618 access_token=other_user_token,
26182619 )
26192620
2620 self.assertEqual(403, channel.code, msg=channel.json_body)
2621 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
26212622 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
26222623
26232624 def test_user_does_not_exist(self):
26242625 """
2625 Tests that a lookup for a user that does not exist returns a 404
2626 Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
26262627 """
26272628 url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
26282629 channel = self.make_request(
26312632 access_token=self.admin_user_tok,
26322633 )
26332634
2634 self.assertEqual(404, channel.code, msg=channel.json_body)
2635 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
26352636 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
26362637
26372638 def test_user_is_not_local(self):
26382639 """
2639 Tests that a lookup for a user that is not a local returns a 400
2640 Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
26402641 """
26412642 url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
26422643
26462647 access_token=self.admin_user_tok,
26472648 )
26482649
2649 self.assertEqual(400, channel.code, msg=channel.json_body)
2650 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
26502651 self.assertEqual("Can only look up local users", channel.json_body["error"])
26512652
26522653 def test_get_pushers(self):
26612662 access_token=self.admin_user_tok,
26622663 )
26632664
2664 self.assertEqual(200, channel.code, msg=channel.json_body)
2665 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
26652666 self.assertEqual(0, channel.json_body["total"])
26662667
26672668 # Register the pusher
26922693 access_token=self.admin_user_tok,
26932694 )
26942695
2695 self.assertEqual(200, channel.code, msg=channel.json_body)
2696 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
26962697 self.assertEqual(1, channel.json_body["total"])
26972698
26982699 for p in channel.json_body["pushers"]:
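The hunks in this file replace a small, fixed set of bare codes. Since `HTTPStatus(n)` looks a member up by its integer value, the mapping the diff applies can be checked mechanically:

```python
from http import HTTPStatus

# The status codes replaced throughout this diff and the HTTPStatus
# members they became; HTTPStatus(n) returns the member whose value is n.
for code, member in [
    (400, HTTPStatus.BAD_REQUEST),
    (401, HTTPStatus.UNAUTHORIZED),
    (403, HTTPStatus.FORBIDDEN),
    (404, HTTPStatus.NOT_FOUND),
    (409, HTTPStatus.CONFLICT),
]:
    assert HTTPStatus(code) is member
    print(code, member.name)
```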
2731 2732    """Try to list media of an user without authentication."""
2732 2733    channel = self.make_request(method, self.url, {})
2733 2734
2734      -  self.assertEqual(401, channel.code, msg=channel.json_body)
     2735 +  self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
2735 2736    self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
2736 2737
2737 2738    @parameterized.expand(["GET", "DELETE"])
2745 2746    access_token=other_user_token,
2746 2747    )
2747 2748
2748      -  self.assertEqual(403, channel.code, msg=channel.json_body)
     2749 +  self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
2749 2750    self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
2750 2751
2751 2752    @parameterized.expand(["GET", "DELETE"])
2752 2753    def test_user_does_not_exist(self, method: str):
2753      -  """Tests that a lookup for a user that does not exist returns a 404"""
     2754 +  """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
2754 2755    url = "/_synapse/admin/v1/users/@unknown_person:test/media"
2755 2756    channel = self.make_request(
2756 2757    method,
2758 2759    access_token=self.admin_user_tok,
2759 2760    )
2760 2761
2761      -  self.assertEqual(404, channel.code, msg=channel.json_body)
     2762 +  self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
2762 2763    self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
2763 2764
2764 2765    @parameterized.expand(["GET", "DELETE"])
2765 2766    def test_user_is_not_local(self, method: str):
2766      -  """Tests that a lookup for a user that is not a local returns a 400"""
     2767 +  """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
2767 2768    url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
2768 2769
2769 2770    channel = self.make_request(
2772 2773    access_token=self.admin_user_tok,
2773 2774    )
2774 2775
2775      -  self.assertEqual(400, channel.code, msg=channel.json_body)
     2776 +  self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
2776 2777    self.assertEqual("Can only look up local users", channel.json_body["error"])
2777 2778
2778 2779    def test_limit_GET(self):
2788 2789    access_token=self.admin_user_tok,
2789 2790    )
2790 2791
2791      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2792 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2792 2793    self.assertEqual(channel.json_body["total"], number_media)
2793 2794    self.assertEqual(len(channel.json_body["media"]), 5)
2794 2795    self.assertEqual(channel.json_body["next_token"], 5)
2807 2808    access_token=self.admin_user_tok,
2808 2809    )
2809 2810
2810      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2811 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2811 2812    self.assertEqual(channel.json_body["total"], 5)
2812 2813    self.assertEqual(len(channel.json_body["deleted_media"]), 5)
2813 2814
2824 2825    access_token=self.admin_user_tok,
2825 2826    )
2826 2827
2827      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2828 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2828 2829    self.assertEqual(channel.json_body["total"], number_media)
2829 2830    self.assertEqual(len(channel.json_body["media"]), 15)
2830 2831    self.assertNotIn("next_token", channel.json_body)
2843 2844    access_token=self.admin_user_tok,
2844 2845    )
2845 2846
2846      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2847 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2847 2848    self.assertEqual(channel.json_body["total"], 15)
2848 2849    self.assertEqual(len(channel.json_body["deleted_media"]), 15)
2849 2850
2860 2861    access_token=self.admin_user_tok,
2861 2862    )
2862 2863
2863      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2864 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2864 2865    self.assertEqual(channel.json_body["total"], number_media)
2865 2866    self.assertEqual(channel.json_body["next_token"], 15)
2866 2867    self.assertEqual(len(channel.json_body["media"]), 10)
2879 2880    access_token=self.admin_user_tok,
2880 2881    )
2881 2882
2882      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2883 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2883 2884    self.assertEqual(channel.json_body["total"], 10)
2884 2885    self.assertEqual(len(channel.json_body["deleted_media"]), 10)
2885 2886
2893 2894    access_token=self.admin_user_tok,
2894 2895    )
2895 2896
2896      -  self.assertEqual(400, channel.code, msg=channel.json_body)
     2897 +  self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
2897 2898    self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
2898 2899
2899 2900    # invalid search order
2903 2904    access_token=self.admin_user_tok,
2904 2905    )
2905 2906
2906      -  self.assertEqual(400, channel.code, msg=channel.json_body)
     2907 +  self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
2907 2908    self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
2908 2909
2909 2910    # negative limit
2913 2914    access_token=self.admin_user_tok,
2914 2915    )
2915 2916
2916      -  self.assertEqual(400, channel.code, msg=channel.json_body)
     2917 +  self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
2917 2918    self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
2918 2919
2919 2920    # negative from
2923 2924    access_token=self.admin_user_tok,
2924 2925    )
2925 2926
2926      -  self.assertEqual(400, channel.code, msg=channel.json_body)
     2927 +  self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
2927 2928    self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
2928 2929
2929 2930    def test_next_token(self):
2946 2947    access_token=self.admin_user_tok,
2947 2948    )
2948 2949
2949      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2950 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2950 2951    self.assertEqual(channel.json_body["total"], number_media)
2951 2952    self.assertEqual(len(channel.json_body["media"]), number_media)
2952 2953    self.assertNotIn("next_token", channel.json_body)
2959 2960    access_token=self.admin_user_tok,
2960 2961    )
2961 2962
2962      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2963 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2963 2964    self.assertEqual(channel.json_body["total"], number_media)
2964 2965    self.assertEqual(len(channel.json_body["media"]), number_media)
2965 2966    self.assertNotIn("next_token", channel.json_body)
2972 2973    access_token=self.admin_user_tok,
2973 2974    )
2974 2975
2975      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2976 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2976 2977    self.assertEqual(channel.json_body["total"], number_media)
2977 2978    self.assertEqual(len(channel.json_body["media"]), 19)
2978 2979    self.assertEqual(channel.json_body["next_token"], 19)
2986 2987    access_token=self.admin_user_tok,
2987 2988    )
2988 2989
2989      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     2990 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
2990 2991    self.assertEqual(channel.json_body["total"], number_media)
2991 2992    self.assertEqual(len(channel.json_body["media"]), 1)
2992 2993    self.assertNotIn("next_token", channel.json_body)
3003 3004    access_token=self.admin_user_tok,
3004 3005    )
3005 3006
3006      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3007 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3007 3008    self.assertEqual(0, channel.json_body["total"])
3008 3009    self.assertEqual(0, len(channel.json_body["media"]))
3009 3010
3018 3019    access_token=self.admin_user_tok,
3019 3020    )
3020 3021
3021      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3022 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3022 3023    self.assertEqual(0, channel.json_body["total"])
3023 3024    self.assertEqual(0, len(channel.json_body["deleted_media"]))
3024 3025
3035 3036    access_token=self.admin_user_tok,
3036 3037    )
3037 3038
3038      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3039 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3039 3040    self.assertEqual(number_media, channel.json_body["total"])
3040 3041    self.assertEqual(number_media, len(channel.json_body["media"]))
3041 3042    self.assertNotIn("next_token", channel.json_body)
3061 3062    access_token=self.admin_user_tok,
3062 3063    )
3063 3064
3064      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3065 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3065 3066    self.assertEqual(number_media, channel.json_body["total"])
3066 3067    self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
3067 3068    self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
3206 3207
3207 3208    # Upload some media into the room
3208 3209    response = self.helper.upload_media(
3209      -  upload_resource, image_data, user_token, filename, expect_code=200
     3210 +  upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK
3210 3211    )
3211 3212
3212 3213    # Extract media ID from the response
3224 3225    )
3225 3226
3226 3227    self.assertEqual(
3227      -  200,
     3228 +  HTTPStatus.OK,
3228 3229    channel.code,
3229 3230    msg=(
3230      -  f"Expected to receive a 200 on accessing media: {server_and_media_id}"
     3231 +  f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}"
3231 3232    ),
3232 3233    )
3233 3234
3234 3235    return media_id
3235 3236
3236      -  def _check_fields(self, content: JsonDict):
     3237 +  def _check_fields(self, content: List[JsonDict]):
3237 3238    """Checks that the expected user attributes are present in content
3238 3239    Args:
3239 3240    content: List that is checked for content
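The annotation fix above (`JsonDict` to `List[JsonDict]`) brings the signature in line with the docstring, which already says the argument is a list. In Synapse, `JsonDict` is an alias for `Dict[str, Any]`; the helper name and field checks below are hypothetical stand-ins used only to make the sketch self-contained:

```python
from typing import Any, Dict, List

# JsonDict mirrors Synapse's alias (Dict[str, Any]); redefined here so
# the sketch runs on its own.
JsonDict = Dict[str, Any]

def check_fields(content: List[JsonDict]) -> None:
    """Hypothetical stand-in for _check_fields: the corrected annotation
    reflects that the helper iterates a *list* of JSON dicts."""
    for entry in content:
        assert "name" in entry
        assert "displayname" in entry

check_fields([{"name": "@user:test", "displayname": "foobar"}])
```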
3273 3274    url,
3274 3275    access_token=self.admin_user_tok,
3275 3276    )
3276      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3277 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3277 3278    self.assertEqual(channel.json_body["total"], len(expected_media_list))
3278 3279
3279 3280    returned_order = [row["media_id"] for row in channel.json_body["media"]]
3309 3310    channel = self.make_request(
3310 3311    "POST", self.url, b"{}", access_token=self.admin_user_tok
3311 3312    )
3312      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3313 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3313 3314    return channel.json_body["access_token"]
3314 3315
3315 3316    def test_no_auth(self):
3316 3317    """Try to login as a user without authentication."""
3317 3318    channel = self.make_request("POST", self.url, b"{}")
3318 3319
3319      -  self.assertEqual(401, channel.code, msg=channel.json_body)
     3320 +  self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
3320 3321    self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
3321 3322
3322 3323    def test_not_admin(self):
3325 3326    "POST", self.url, b"{}", access_token=self.other_user_tok
3326 3327    )
3327 3328
3328      -  self.assertEqual(403, channel.code, msg=channel.json_body)
     3329 +  self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
3329 3330
3330 3331    def test_send_event(self):
3331 3332    """Test that sending event as a user works."""
3350 3351    channel = self.make_request(
3351 3352    "GET", "devices", b"{}", access_token=self.other_user_tok
3352 3353    )
3353      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3354 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3354 3355
3355 3356    # We should only see the one device (from the login in `prepare`)
3356 3357    self.assertEqual(len(channel.json_body["devices"]), 1)
3362 3363
3363 3364    # Test that we can successfully make a request
3364 3365    channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
3365      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3366 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3366 3367
3367 3368    # Logout with the puppet token
3368 3369    channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
3369      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3370 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3370 3371
3371 3372    # The puppet token should no longer work
3372 3373    channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
3373      -  self.assertEqual(401, channel.code, msg=channel.json_body)
     3374 +  self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
3374 3375
3375 3376    # .. but the real user's tokens should still work
3376 3377    channel = self.make_request(
3377 3378    "GET", "devices", b"{}", access_token=self.other_user_tok
3378 3379    )
3379      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3380 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3380 3381
3381 3382    def test_user_logout_all(self):
3382 3383    """Tests that the target user calling `/logout/all` does *not* expire
3387 3388
3388 3389    # Test that we can successfully make a request
3389 3390    channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
3390      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3391 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3391 3392
3392 3393    # Logout all with the real user token
3393 3394    channel = self.make_request(
3394 3395    "POST", "logout/all", b"{}", access_token=self.other_user_tok
3395 3396    )
3396      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3397 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3397 3398
3398 3399    # The puppet token should still work
3399 3400    channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
3400      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3401 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3401 3402
3402 3403    # .. but the real user's tokens shouldn't
3403 3404    channel = self.make_request(
3404 3405    "GET", "devices", b"{}", access_token=self.other_user_tok
3405 3406    )
3406      -  self.assertEqual(401, channel.code, msg=channel.json_body)
     3407 +  self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
3407 3408
3408 3409    def test_admin_logout_all(self):
3409 3410    """Tests that the admin user calling `/logout/all` does expire the
3414 3415
3415 3416    # Test that we can successfully make a request
3416 3417    channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
3417      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3418 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3418 3419
3419 3420    # Logout all with the admin user token
3420 3421    channel = self.make_request(
3421 3422    "POST", "logout/all", b"{}", access_token=self.admin_user_tok
3422 3423    )
3423      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3424 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3424 3425
3425 3426    # The puppet token should no longer work
3426 3427    channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
3427      -  self.assertEqual(401, channel.code, msg=channel.json_body)
     3428 +  self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
3428 3429
3429 3430    # .. but the real user's tokens should still work
3430 3431    channel = self.make_request(
3431 3432    "GET", "devices", b"{}", access_token=self.other_user_tok
3432 3433    )
3433      -  self.assertEqual(200, channel.code, msg=channel.json_body)
     3434 +  self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
3434 3435
3435 3436    @unittest.override_config(
3436 3437    {
3458 3459    # Now unaccept it and check that we can't send an event
3459 3460    self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
3460 3461    self.helper.send_event(
3461      -  room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
     3462 +  room_id,
     3463 +  "com.example.test",
     3464 +  tok=self.other_user_tok,
     3465 +  expect_code=HTTPStatus.FORBIDDEN,
3462 3466    )
3463 3467
3464 3468    # Login in as the user
3476 3480
3477 3481    # Trying to join as the other user should fail due to reaching MAU limit.
3478 3482    self.helper.join(
3479      -  room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
     3483 +  room_id,
     3484 +  user=self.other_user,
     3485 +  tok=self.other_user_tok,
     3486 +  expect_code=HTTPStatus.FORBIDDEN,
3480 3487    )
3481 3488
3482 3489    # Logging in as the other user and joining a room should work, even
35113518 Try to get information of a user without authentication.
35123519 """
35133520 channel = self.make_request("GET", self.url, b"{}")
3514 self.assertEqual(401, channel.code, msg=channel.json_body)
3521 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
35153522 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
35163523
35173524 def test_requester_is_not_admin(self):
35263533 self.url,
35273534 access_token=other_user2_token,
35283535 )
3529 self.assertEqual(403, channel.code, msg=channel.json_body)
3536 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
35303537 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
35313538
35323539 def test_user_is_not_local(self):
35333540 """
3534 Tests that a lookup for a user that is not a local returns a 400
3541 Tests that a lookup for a user that is not local returns an HTTPStatus.BAD_REQUEST
35353542 """
35363543 url = self.url_prefix % "@unknown_person:unknown_domain"
35373544
35403547 url,
35413548 access_token=self.admin_user_tok,
35423549 )
3543 self.assertEqual(400, channel.code, msg=channel.json_body)
3550 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
35443551 self.assertEqual("Can only whois a local user", channel.json_body["error"])
35453552
35463553 def test_get_whois_admin(self):
35523559 self.url,
35533560 access_token=self.admin_user_tok,
35543561 )
3555 self.assertEqual(200, channel.code, msg=channel.json_body)
3562 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
35563563 self.assertEqual(self.other_user, channel.json_body["user_id"])
35573564 self.assertIn("devices", channel.json_body)
35583565
35673574 self.url,
35683575 access_token=other_user_token,
35693576 )
3570 self.assertEqual(200, channel.code, msg=channel.json_body)
3577 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
35713578 self.assertEqual(self.other_user, channel.json_body["user_id"])
35723579 self.assertIn("devices", channel.json_body)
35733580
35973604 Try to get information of a user without authentication.
35983605 """
35993606 channel = self.make_request(method, self.url)
3600 self.assertEqual(401, channel.code, msg=channel.json_body)
3607 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
36013608 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
36023609
36033610 @parameterized.expand(["POST", "DELETE"])
36083615 other_user_token = self.login("user", "pass")
36093616
36103617 channel = self.make_request(method, self.url, access_token=other_user_token)
3611 self.assertEqual(403, channel.code, msg=channel.json_body)
3618 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
36123619 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
36133620
36143621 @parameterized.expand(["POST", "DELETE"])
36153622 def test_user_is_not_local(self, method: str):
36163623 """
3617 Tests that shadow-banning for a user that is not a local returns a 400
3624 Tests that shadow-banning a user that is not local returns an HTTPStatus.BAD_REQUEST
36183625 """
36193626 url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
36203627
36213628 channel = self.make_request(method, url, access_token=self.admin_user_tok)
3622 self.assertEqual(400, channel.code, msg=channel.json_body)
3629 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
36233630
36243631 def test_success(self):
36253632 """
36313638 self.assertFalse(result.shadow_banned)
36323639
36333640 channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
3634 self.assertEqual(200, channel.code, msg=channel.json_body)
3641 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
36353642 self.assertEqual({}, channel.json_body)
36363643
36373644 # Ensure the user is shadow-banned (and the cache was cleared).
36423649 channel = self.make_request(
36433650 "DELETE", self.url, access_token=self.admin_user_tok
36443651 )
3645 self.assertEqual(200, channel.code, msg=channel.json_body)
3652 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
36463653 self.assertEqual({}, channel.json_body)
36473654
36483655 # Ensure the user is no longer shadow-banned (and the cache was cleared).
36763683 """
36773684 channel = self.make_request(method, self.url, b"{}")
36783685
3679 self.assertEqual(401, channel.code, msg=channel.json_body)
3686 self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
36803687 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
36813688
36823689 @parameterized.expand(["GET", "POST", "DELETE"])
36923699 access_token=other_user_token,
36933700 )
36943701
3695 self.assertEqual(403, channel.code, msg=channel.json_body)
3702 self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
36963703 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
36973704
36983705 @parameterized.expand(["GET", "POST", "DELETE"])
36993706 def test_user_does_not_exist(self, method: str):
37003707 """
3701 Tests that a lookup for a user that does not exist returns a 404
3708 Tests that a lookup for a user that does not exist returns an HTTPStatus.NOT_FOUND
37023709 """
37033710 url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
37043711
37083715 access_token=self.admin_user_tok,
37093716 )
37103717
3711 self.assertEqual(404, channel.code, msg=channel.json_body)
3718 self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
37123719 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
37133720
37143721 @parameterized.expand(
37203727 )
37213728 def test_user_is_not_local(self, method: str, error_msg: str):
37223729 """
3723 Tests that a lookup for a user that is not a local returns a 400
3730 Tests that a lookup for a user that is not local returns an HTTPStatus.BAD_REQUEST
37243731 """
37253732 url = (
37263733 "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
37323739 access_token=self.admin_user_tok,
37333740 )
37343741
3735 self.assertEqual(400, channel.code, msg=channel.json_body)
3742 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
37363743 self.assertEqual(error_msg, channel.json_body["error"])
37373744
37383745 def test_invalid_parameter(self):
37473754 content={"messages_per_second": "string"},
37483755 )
37493756
3750 self.assertEqual(400, channel.code, msg=channel.json_body)
3757 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
37513758 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
37523759
37533760 # messages_per_second is negative
37583765 content={"messages_per_second": -1},
37593766 )
37603767
3761 self.assertEqual(400, channel.code, msg=channel.json_body)
3768 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
37623769 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
37633770
37643771 # burst_count is a string
37693776 content={"burst_count": "string"},
37703777 )
37713778
3772 self.assertEqual(400, channel.code, msg=channel.json_body)
3779 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
37733780 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
37743781
37753782 # burst_count is negative
37803787 content={"burst_count": -1},
37813788 )
37823789
3783 self.assertEqual(400, channel.code, msg=channel.json_body)
3790 self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
37843791 self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
37853792
37863793 def test_return_zero_when_null(self):
38053812 self.url,
38063813 access_token=self.admin_user_tok,
38073814 )
3808 self.assertEqual(200, channel.code, msg=channel.json_body)
3815 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38093816 self.assertEqual(0, channel.json_body["messages_per_second"])
38103817 self.assertEqual(0, channel.json_body["burst_count"])
38113818
38193826 self.url,
38203827 access_token=self.admin_user_tok,
38213828 )
3822 self.assertEqual(200, channel.code, msg=channel.json_body)
3829 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38233830 self.assertNotIn("messages_per_second", channel.json_body)
38243831 self.assertNotIn("burst_count", channel.json_body)
38253832
38303837 access_token=self.admin_user_tok,
38313838 content={"messages_per_second": 10, "burst_count": 11},
38323839 )
3833 self.assertEqual(200, channel.code, msg=channel.json_body)
3840 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38343841 self.assertEqual(10, channel.json_body["messages_per_second"])
38353842 self.assertEqual(11, channel.json_body["burst_count"])
38363843
38413848 access_token=self.admin_user_tok,
38423849 content={"messages_per_second": 20, "burst_count": 21},
38433850 )
3844 self.assertEqual(200, channel.code, msg=channel.json_body)
3851 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38453852 self.assertEqual(20, channel.json_body["messages_per_second"])
38463853 self.assertEqual(21, channel.json_body["burst_count"])
38473854
38513858 self.url,
38523859 access_token=self.admin_user_tok,
38533860 )
3854 self.assertEqual(200, channel.code, msg=channel.json_body)
3861 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38553862 self.assertEqual(20, channel.json_body["messages_per_second"])
38563863 self.assertEqual(21, channel.json_body["burst_count"])
38573864
38613868 self.url,
38623869 access_token=self.admin_user_tok,
38633870 )
3864 self.assertEqual(200, channel.code, msg=channel.json_body)
3871 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38653872 self.assertNotIn("messages_per_second", channel.json_body)
38663873 self.assertNotIn("burst_count", channel.json_body)
38673874
38713878 self.url,
38723879 access_token=self.admin_user_tok,
38733880 )
3874 self.assertEqual(200, channel.code, msg=channel.json_body)
3881 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
38753882 self.assertNotIn("messages_per_second", channel.json_body)
38763883 self.assertNotIn("burst_count", channel.json_body)
1010 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
13
14 from http import HTTPStatus
1315
1416 import synapse.rest.admin
1517 from synapse.api.errors import Codes, SynapseError
3234 async def check_username(username):
3335 if username == "allowed":
3436 return True
35 raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
37 raise SynapseError(
38 HTTPStatus.BAD_REQUEST,
39 "User ID already taken.",
40 errcode=Codes.USER_IN_USE,
41 )
3642
3743 handler = self.hs.get_registration_handler()
3844 handler.check_username = check_username
3945
4046 def test_username_available(self):
4147 """
42 The endpoint should return a 200 response if the username does not exist
48 The endpoint should return an HTTPStatus.OK response if the username does not exist
4349 """
4450
4551 url = "%s?username=%s" % (self.url, "allowed")
4652 channel = self.make_request("GET", url, None, self.admin_user_tok)
4753
48 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
54 self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
4955 self.assertTrue(channel.json_body["available"])
5056
5157 def test_username_unavailable(self):
5258 """
53 The endpoint should return a 200 response if the username does not exist
59 The endpoint should return an HTTPStatus.BAD_REQUEST response if the username is already in use
5460 """
5561
5662 url = "%s?username=%s" % (self.url, "disallowed")
5763 channel = self.make_request("GET", url, None, self.admin_user_tok)
5864
59 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
65 self.assertEqual(
66 HTTPStatus.BAD_REQUEST,
67 channel.code,
68 msg=channel.json_body,
69 )
6070 self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
6171 self.assertEqual(channel.json_body["error"], "User ID already taken.")
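The hunks in this file, like most of this diff, swap bare integer status codes for Python's `http.HTTPStatus` enum. Because `HTTPStatus` is an `IntEnum`, its members compare equal to the plain integers, so the rewritten assertions are behaviourally identical; a minimal sketch of the equivalence:

```python
from http import HTTPStatus

# HTTPStatus is an IntEnum, so members compare equal to the bare
# integers the assertions previously used.
assert HTTPStatus.OK == 200
assert HTTPStatus.BAD_REQUEST == 400
assert HTTPStatus.UNAUTHORIZED == 401
assert HTTPStatus.FORBIDDEN == 403
assert HTTPStatus.NOT_FOUND == 404

# Members also carry a readable phrase, handy in failure messages.
assert HTTPStatus.NOT_FOUND.phrase == "Not Found"
```

The named constants make the intent of each assertion legible without changing what is actually compared.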
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from http import HTTPStatus
1415 from typing import Optional, Union
1516
1617 from twisted.internet.defer import succeed
512513 self.user_pass = "pass"
513514 self.user = self.register_user("test", self.user_pass)
514515
516 def use_refresh_token(self, refresh_token: str) -> FakeChannel:
517 """
518 Helper that makes a request to use a refresh token.
519 """
520 return self.make_request(
521 "POST",
522 "/_matrix/client/v1/refresh",
523 {"refresh_token": refresh_token},
524 )
525
526 def is_access_token_valid(self, access_token) -> bool:
527 """
528 Checks whether an access token is valid, returning whether it is or not.
529 """
530 code = self.make_request(
531 "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
532 ).code
533
534 # Either 200 or 401 is what we get back; anything else is a bug.
535 assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
536
537 return code == HTTPStatus.OK
538
515539 def test_login_issue_refresh_token(self):
516540 """
517541 A login response should include a refresh_token only if asked.
518542 """
519543 # Test login
520 body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
544 body = {
545 "type": "m.login.password",
546 "user": "test",
547 "password": self.user_pass,
548 }
521549
522550 login_without_refresh = self.make_request(
523551 "POST", "/_matrix/client/r0/login", body
527555
528556 login_with_refresh = self.make_request(
529557 "POST",
530 "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
531 body,
558 "/_matrix/client/r0/login",
559 {"refresh_token": True, **body},
532560 )
533561 self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
534562 self.assertIn("refresh_token", login_with_refresh.json_body)
554582
555583 register_with_refresh = self.make_request(
556584 "POST",
557 "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true",
585 "/_matrix/client/r0/register",
558586 {
559587 "username": "test3",
560588 "password": self.user_pass,
561589 "auth": {"type": LoginType.DUMMY},
590 "refresh_token": True,
562591 },
563592 )
564593 self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
569598 """
570599 A refresh token can be used to issue a new access token.
571600 """
572 body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
601 body = {
602 "type": "m.login.password",
603 "user": "test",
604 "password": self.user_pass,
605 "refresh_token": True,
606 }
573607 login_response = self.make_request(
574608 "POST",
575 "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
609 "/_matrix/client/r0/login",
576610 body,
577611 )
578612 self.assertEqual(login_response.code, 200, login_response.result)
579613
580614 refresh_response = self.make_request(
581615 "POST",
582 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
616 "/_matrix/client/v1/refresh",
583617 {"refresh_token": login_response.json_body["refresh_token"]},
584618 )
585619 self.assertEqual(refresh_response.code, 200, refresh_response.result)
598632 )
599633
600634 @override_config({"refreshable_access_token_lifetime": "1m"})
601 def test_refresh_token_expiration(self):
635 def test_refreshable_access_token_expiration(self):
602636 """
603637 The access token should have some time as specified in the config.
604638 """
605 body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
639 body = {
640 "type": "m.login.password",
641 "user": "test",
642 "password": self.user_pass,
643 "refresh_token": True,
644 }
606645 login_response = self.make_request(
607646 "POST",
608 "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
647 "/_matrix/client/r0/login",
609648 body,
610649 )
611650 self.assertEqual(login_response.code, 200, login_response.result)
615654
616655 refresh_response = self.make_request(
617656 "POST",
618 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
657 "/_matrix/client/v1/refresh",
619658 {"refresh_token": login_response.json_body["refresh_token"]},
620659 )
621660 self.assertEqual(refresh_response.code, 200, refresh_response.result)
622661 self.assertApproximates(
623662 refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
624663 )
664 access_token = refresh_response.json_body["access_token"]
665
666 # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
667 self.reactor.advance(59.0)
668 # Check that our token is valid
669 self.assertEqual(
670 self.make_request(
671 "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
672 ).code,
673 HTTPStatus.OK,
674 )
675
676 # Advance 2 more seconds (just past the time of expiry)
677 self.reactor.advance(2.0)
678 # Check that our token is invalid
679 self.assertEqual(
680 self.make_request(
681 "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
682 ).code,
683 HTTPStatus.UNAUTHORIZED,
684 )
685
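The expiry test above advances a fake reactor clock to either side of the configured lifetime. Conceptually, the server-side check reduces to comparing the current time with an absolute expiry timestamp fixed at issuance; a toy sketch of that comparison (hypothetical `Token` class, not Synapse's actual implementation):

```python
class Token:
    """Toy access token carrying an absolute expiry timestamp in ms."""

    def __init__(self, lifetime_ms: int, now_ms: int) -> None:
        self.expiry_ts_ms = now_ms + lifetime_ms

    def is_valid(self, now_ms: int) -> bool:
        # Valid strictly before the expiry instant, rejected afterwards.
        return now_ms < self.expiry_ts_ms

# Mirror the test: a 1-minute token checked at 59 s and 61 s.
token = Token(lifetime_ms=60_000, now_ms=0)
assert token.is_valid(59_000)       # just shy of expiry -> still valid
assert not token.is_valid(61_000)   # just past expiry -> rejected
```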
686 @override_config(
687 {
688 "refreshable_access_token_lifetime": "1m",
689 "nonrefreshable_access_token_lifetime": "10m",
690 }
691 )
692 def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
693 """
694 Tests that the expiry times for refreshable and non-refreshable access
695 tokens can be different.
696 """
697 body = {
698 "type": "m.login.password",
699 "user": "test",
700 "password": self.user_pass,
701 }
702 login_response1 = self.make_request(
703 "POST",
704 "/_matrix/client/r0/login",
705 {"refresh_token": True, **body},
706 )
707 self.assertEqual(login_response1.code, 200, login_response1.result)
708 self.assertApproximates(
709 login_response1.json_body["expires_in_ms"], 60 * 1000, 100
710 )
711 refreshable_access_token = login_response1.json_body["access_token"]
712
713 login_response2 = self.make_request(
714 "POST",
715 "/_matrix/client/r0/login",
716 body,
717 )
718 self.assertEqual(login_response2.code, 200, login_response2.result)
719 nonrefreshable_access_token = login_response2.json_body["access_token"]
720
721 # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
722 self.reactor.advance(59.0)
723
724 # Both tokens should still be valid.
725 self.assertTrue(self.is_access_token_valid(refreshable_access_token))
726 self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
727
728 # Advance to 61 s (just past 1 minute, the time of expiry)
729 self.reactor.advance(2.0)
730
731 # Only the non-refreshable token is still valid.
732 self.assertFalse(self.is_access_token_valid(refreshable_access_token))
733 self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
734
735 # Advance to 599 s (just shy of 10 minutes, the time of expiry)
736 self.reactor.advance(599.0 - 61.0)
737
738 # It's still the case that only the non-refreshable token is still valid.
739 self.assertFalse(self.is_access_token_valid(refreshable_access_token))
740 self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
741
742 # Advance to 601 s (just past 10 minutes, the time of expiry)
743 self.reactor.advance(2.0)
744
745 # Now neither token is valid.
746 self.assertFalse(self.is_access_token_valid(refreshable_access_token))
747 self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
748
749 @override_config(
750 {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
751 )
752 def test_refresh_token_expiry(self):
753 """
754 The refresh token can be configured to have a limited lifetime.
755 When that lifetime has ended, the refresh token can no longer be used to
756 refresh the session.
757 """
758
759 body = {
760 "type": "m.login.password",
761 "user": "test",
762 "password": self.user_pass,
763 "refresh_token": True,
764 }
765 login_response = self.make_request(
766 "POST",
767 "/_matrix/client/r0/login",
768 body,
769 )
770 self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
771 refresh_token1 = login_response.json_body["refresh_token"]
772
773 # Advance 119 seconds in the future (just shy of 2 minutes)
774 self.reactor.advance(119.0)
775
776 # Refresh our session. The refresh token should still JUST be valid right now.
777 # By doing so, we get a new access token and a new refresh token.
778 refresh_response = self.use_refresh_token(refresh_token1)
779 self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
780 self.assertIn(
781 "refresh_token",
782 refresh_response.json_body,
783 "No new refresh token returned after refresh.",
784 )
785 refresh_token2 = refresh_response.json_body["refresh_token"]
786
787 # Advance 121 seconds in the future (just a bit more than 2 minutes)
788 self.reactor.advance(121.0)
789
790 # Try to refresh our session, but instead notice that the refresh token is
791 # not valid (it just expired).
792 refresh_response = self.use_refresh_token(refresh_token2)
793 self.assertEqual(
794 refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
795 )
796
797 @override_config(
798 {
799 "refreshable_access_token_lifetime": "2m",
800 "refresh_token_lifetime": "2m",
801 "session_lifetime": "3m",
802 }
803 )
804 def test_ultimate_session_expiry(self):
805 """
806 The session can be configured to have an ultimate, limited lifetime.
807 """
808
809 body = {
810 "type": "m.login.password",
811 "user": "test",
812 "password": self.user_pass,
813 "refresh_token": True,
814 }
815 login_response = self.make_request(
816 "POST",
817 "/_matrix/client/r0/login",
818 body,
819 )
820 self.assertEqual(login_response.code, 200, login_response.result)
821 refresh_token = login_response.json_body["refresh_token"]
822
823 # Advance shy of 2 minutes into the future
824 self.reactor.advance(119.0)
825
826 # Refresh our session. The refresh token should still be valid right now.
827 refresh_response = self.use_refresh_token(refresh_token)
828 self.assertEqual(refresh_response.code, 200, refresh_response.result)
829 self.assertIn(
830 "refresh_token",
831 refresh_response.json_body,
832 "No new refresh token returned after refresh.",
833 )
834 # Notice that our access token lifetime has been diminished to match the
835 # session lifetime.
836 # 3 minutes - 119 seconds = 61 seconds.
837 self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000)
838 refresh_token = refresh_response.json_body["refresh_token"]
839
840 # Advance 61 seconds into the future. Our session should have expired
841 # now, because we've had our 3 minutes.
842 self.reactor.advance(61.0)
843
844 # Try to issue a new, refreshed, access token.
845 # This should fail because the refresh token's lifetime has also been
846 # diminished as our session expired.
847 refresh_response = self.use_refresh_token(refresh_token)
848 self.assertEqual(refresh_response.code, 403, refresh_response.result)
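The `61_000` ms figure asserted in the test above falls out of capping the access-token lifetime by the remaining session time (3 minutes minus the 119 seconds already elapsed). A short sketch of that arithmetic, using a hypothetical helper name rather than anything from Synapse itself:

```python
def capped_expires_in_ms(
    token_lifetime_ms: int, session_end_ms: int, now_ms: int
) -> int:
    """Access-token lifetime capped by the remaining session time
    (hypothetical helper mirroring the test's arithmetic)."""
    return min(token_lifetime_ms, session_end_ms - now_ms)

# 2-minute refreshable token, 3-minute session, refresh at t = 119 s:
# min(120_000, 180_000 - 119_000) = 61_000 ms, as asserted above.
assert capped_expires_in_ms(120_000, 180_000, 119_000) == 61_000
```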
625849
626850 def test_refresh_token_invalidation(self):
627851 """Refresh tokens are invalidated after first use of the next token.
639863 |-> fourth_refresh (fails)
640864 """
641865
642 body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
866 body = {
867 "type": "m.login.password",
868 "user": "test",
869 "password": self.user_pass,
870 "refresh_token": True,
871 }
643872 login_response = self.make_request(
644873 "POST",
645 "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
874 "/_matrix/client/r0/login",
646875 body,
647876 )
648877 self.assertEqual(login_response.code, 200, login_response.result)
650879 # This first refresh should work properly
651880 first_refresh_response = self.make_request(
652881 "POST",
653 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
882 "/_matrix/client/v1/refresh",
654883 {"refresh_token": login_response.json_body["refresh_token"]},
655884 )
656885 self.assertEqual(
660889 # This one as well, since the token in the first one was never used
661890 second_refresh_response = self.make_request(
662891 "POST",
663 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
892 "/_matrix/client/v1/refresh",
664893 {"refresh_token": login_response.json_body["refresh_token"]},
665894 )
666895 self.assertEqual(
670899 # This one should not, since the token from the first refresh is not valid anymore
671900 third_refresh_response = self.make_request(
672901 "POST",
673 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
902 "/_matrix/client/v1/refresh",
674903 {"refresh_token": first_refresh_response.json_body["refresh_token"]},
675904 )
676905 self.assertEqual(
698927 # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
699928 fourth_refresh_response = self.make_request(
700929 "POST",
701 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
930 "/_matrix/client/v1/refresh",
702931 {"refresh_token": login_response.json_body["refresh_token"]},
703932 )
704933 self.assertEqual(
708937 # But refreshing from the last valid refresh token still works
709938 fifth_refresh_response = self.make_request(
710939 "POST",
711 "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
940 "/_matrix/client/v1/refresh",
712941 {"refresh_token": second_refresh_response.json_body["refresh_token"]},
713942 )
714943 self.assertEqual(
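The invalidation chain exercised above (login issuing first and second refreshes, with stale branches failing) can be read as one rule: a refresh token may be replayed until a token it issued is itself used, and replaying it invalidates the previously issued successor. A simplified toy model of that rule follows; the class and its bookkeeping are hypothetical and deliberately ignore Synapse's actual storage layer:

```python
import itertools


class RefreshTokenStore:
    """Toy model of refresh-token rotation (simplified, hypothetical):
    a refresh token may be replayed until the token it issued is used,
    and replaying it invalidates the previously issued successor."""

    def __init__(self) -> None:
        self._ids = itertools.count(1)
        self._invalid = set()  # tokens that may no longer refresh
        self._successor = {}   # token -> most recently issued successor
        self._parent = {}      # token -> the token that issued it

    def login(self) -> int:
        return next(self._ids)

    def refresh(self, token: int) -> int:
        if token in self._invalid:
            raise PermissionError("refresh token no longer valid")
        # Using this token finalizes its parent: the parent may no
        # longer be replayed.
        parent = self._parent.get(token)
        if parent is not None:
            self._invalid.add(parent)
        # Replaying a token invalidates its previously issued successor.
        superseded = self._successor.get(token)
        if superseded is not None:
            self._invalid.add(superseded)
        new = next(self._ids)
        self._successor[token] = new
        self._parent[new] = token
        return new


# Mirror the diagram in the docstring above:
store = RefreshTokenStore()
login = store.login()
first = store.refresh(login)    # first_refresh: ok
second = store.refresh(login)   # second_refresh: ok, first never used
try:
    store.refresh(first)        # third_refresh: fails, superseded
except PermissionError:
    pass
fifth = store.refresh(second)   # fifth_refresh: ok, finalizes login
try:
    store.refresh(login)        # fourth_refresh: fails, successor used
except PermissionError:
    pass
```

In the real tests the "finalize the parent" step is triggered by using the *access* token issued alongside the refresh token; the sketch collapses that into the refresh call itself.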
1818
1919 from synapse.api.constants import EventTypes, RelationTypes
2020 from synapse.rest import admin
21 from synapse.rest.client import login, register, relations, room
21 from synapse.rest.client import login, register, relations, room, sync
2222
2323 from tests import unittest
2424 from tests.server import FakeChannel
2828 servlets = [
2929 relations.register_servlets,
3030 room.register_servlets,
31 sync.register_servlets,
3132 login.register_servlets,
3233 register.register_servlets,
3334 admin.register_servlets_for_client_rest_resource,
453454 self.assertEquals(400, channel.code, channel.json_body)
454455
455456 @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
456 def test_aggregation_get_event(self):
457 """Test that annotations, references, and threads get correctly bundled when
458 getting the parent event.
459 """
460
457 def test_bundled_aggregations(self):
458 """Test that annotations, references, and threads get correctly bundled."""
459 # Setup by sending a variety of relations.
461460 channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
462461 self.assertEquals(200, channel.code, channel.json_body)
463462
484483 self.assertEquals(200, channel.code, channel.json_body)
485484 thread_2 = channel.json_body["event_id"]
486485
487 channel = self.make_request(
488 "GET",
489 "/rooms/%s/event/%s" % (self.room, self.parent_id),
490 access_token=self.user_token,
491 )
492 self.assertEquals(200, channel.code, channel.json_body)
493
494 self.assertEquals(
495 channel.json_body["unsigned"].get("m.relations"),
496 {
497 RelationTypes.ANNOTATION: {
486 def assert_bundle(actual):
487 """Assert the expected values of the bundled aggregations."""
488
489 # Ensure the fields are as expected.
490 self.assertCountEqual(
491 actual.keys(),
492 (
493 RelationTypes.ANNOTATION,
494 RelationTypes.REFERENCE,
495 RelationTypes.THREAD,
496 ),
497 )
498
499 # Check the values of each field.
500 self.assertEquals(
501 {
498502 "chunk": [
499503 {"type": "m.reaction", "key": "a", "count": 2},
500504 {"type": "m.reaction", "key": "b", "count": 1},
501505 ]
502506 },
503 RelationTypes.REFERENCE: {
504 "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
507 actual[RelationTypes.ANNOTATION],
508 )
509
510 self.assertEquals(
511 {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
512 actual[RelationTypes.REFERENCE],
513 )
514
515 self.assertEquals(
516 2,
517 actual[RelationTypes.THREAD].get("count"),
518 )
519 # The latest thread event has some fields that don't matter.
520 self.assert_dict(
521 {
522 "content": {
523 "m.relates_to": {
524 "event_id": self.parent_id,
525 "rel_type": RelationTypes.THREAD,
526 }
527 },
528 "event_id": thread_2,
529 "room_id": self.room,
530 "sender": self.user_id,
531 "type": "m.room.test",
532 "user_id": self.user_id,
505533 },
506 RelationTypes.THREAD: {
507 "count": 2,
508 "latest_event": {
509 "age": 100,
510 "content": {
511 "m.relates_to": {
512 "event_id": self.parent_id,
513 "rel_type": RelationTypes.THREAD,
514 }
515 },
516 "event_id": thread_2,
517 "origin_server_ts": 1600,
518 "room_id": self.room,
519 "sender": self.user_id,
520 "type": "m.room.test",
521 "unsigned": {"age": 100},
522 "user_id": self.user_id,
523 },
534 actual[RelationTypes.THREAD].get("latest_event"),
535 )
536
537 def _find_and_assert_event(events):
538 """
539 Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
540 """
541 for event in events:
542 if event["event_id"] == self.parent_id:
543 break
544 else:
545 raise AssertionError(f"Event {self.parent_id} not found in chunk")
546 assert_bundle(event["unsigned"].get("m.relations"))
547
548 # Request the event directly.
549 channel = self.make_request(
550 "GET",
551 f"/rooms/{self.room}/event/{self.parent_id}",
552 access_token=self.user_token,
553 )
554 self.assertEquals(200, channel.code, channel.json_body)
555 assert_bundle(channel.json_body["unsigned"].get("m.relations"))
556
557 # Request the room messages.
558 channel = self.make_request(
559 "GET",
560 f"/rooms/{self.room}/messages?dir=b",
561 access_token=self.user_token,
562 )
563 self.assertEquals(200, channel.code, channel.json_body)
564 _find_and_assert_event(channel.json_body["chunk"])
565
566 # Request the room context.
567 channel = self.make_request(
568 "GET",
569 f"/rooms/{self.room}/context/{self.parent_id}",
570 access_token=self.user_token,
571 )
572 self.assertEquals(200, channel.code, channel.json_body)
573 assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
574
575 # Request sync.
576 channel = self.make_request("GET", "/sync", access_token=self.user_token)
577 self.assertEquals(200, channel.code, channel.json_body)
578 room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
579 self.assertTrue(room_timeline["limited"])
580 _find_and_assert_event(room_timeline["events"])
581
582 # Note that /relations is tested separately in test_aggregation_get_event_for_thread
583 # since it needs different data configured.
584
585 def test_aggregation_get_event_for_annotation(self):
586 """Test that annotations do not get bundled aggregations included
587 when directly requested.
588 """
589 channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
590 self.assertEquals(200, channel.code, channel.json_body)
591 annotation_id = channel.json_body["event_id"]
592
593 # Annotate the annotation.
594 channel = self._send_relation(
595 RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
596 )
597 self.assertEquals(200, channel.code, channel.json_body)
598
599 channel = self.make_request(
600 "GET",
601 f"/rooms/{self.room}/event/{annotation_id}",
602 access_token=self.user_token,
603 )
604 self.assertEquals(200, channel.code, channel.json_body)
605 self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
606
607 def test_aggregation_get_event_for_thread(self):
608 """Test that threads get bundled aggregations included when directly requested."""
609 channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
610 self.assertEquals(200, channel.code, channel.json_body)
611 thread_id = channel.json_body["event_id"]
612
613 # Annotate the thread reply.
614 channel = self._send_relation(
615 RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
616 )
617 self.assertEquals(200, channel.code, channel.json_body)
618
619 channel = self.make_request(
620 "GET",
621 f"/rooms/{self.room}/event/{thread_id}",
622 access_token=self.user_token,
623 )
624 self.assertEquals(200, channel.code, channel.json_body)
625 self.assertEquals(
626 channel.json_body["unsigned"].get("m.relations"),
627 {
628 RelationTypes.ANNOTATION: {
629 "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
630 },
631 },
632 )
633
634 # It should also be included when the entire thread is requested.
635 channel = self.make_request(
636 "GET",
637 f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
638 access_token=self.user_token,
639 )
640 self.assertEquals(200, channel.code, channel.json_body)
641 self.assertEqual(len(channel.json_body["chunk"]), 1)
642
643 thread_message = channel.json_body["chunk"][0]
644 self.assertEquals(
645 thread_message["unsigned"].get("m.relations"),
646 {
647 RelationTypes.ANNOTATION: {
648 "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
524649 },
525650 },
526651 )
671796 {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
672797 )
673798
799 def test_edit_edit(self):
800 """Test that an edit cannot be edited."""
801 new_body = {"msgtype": "m.text", "body": "Initial edit"}
802 channel = self._send_relation(
803 RelationTypes.REPLACE,
804 "m.room.message",
805 content={
806 "msgtype": "m.text",
807 "body": "Wibble",
808 "m.new_content": new_body,
809 },
810 )
811 self.assertEquals(200, channel.code, channel.json_body)
812 edit_event_id = channel.json_body["event_id"]
813
814 # Edit the edit event.
815 channel = self._send_relation(
816 RelationTypes.REPLACE,
817 "m.room.message",
818 content={
819 "msgtype": "m.text",
820 "body": "foo",
821 "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"},
822 },
823 parent_id=edit_event_id,
824 )
825 self.assertEquals(200, channel.code, channel.json_body)
826
827 # Request the original event.
828 channel = self.make_request(
829 "GET",
830 "/rooms/%s/event/%s" % (self.room, self.parent_id),
831 access_token=self.user_token,
832 )
833 self.assertEquals(200, channel.code, channel.json_body)
834 # The edit to the edit should be ignored.
835 self.assertEquals(channel.json_body["content"], new_body)
836
837 # The relations information should not include the edit to the edit.
838 relations_dict = channel.json_body["unsigned"].get("m.relations")
839 self.assertIn(RelationTypes.REPLACE, relations_dict)
840
841 m_replace_dict = relations_dict[RelationTypes.REPLACE]
842 for key in ["event_id", "sender", "origin_server_ts"]:
843 self.assertIn(key, m_replace_dict)
844
845 self.assert_dict(
846 {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
847 )
848
674849 def test_relations_redaction_redacts_edits(self):
675850 """Test that edits of an event are redacted when the original event
676851 is redacted.
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313 import inspect
14 import os
1415 from typing import Iterable
1516
16 from synapse.rest.media.v1.filepath import MediaFilePaths
17 from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check
1718
1819 from tests import unittest
1920
485486 f"{value!r} unexpectedly passed validation: "
486487 f"{method} returned {path_or_list!r}"
487488 )
489
490
491 class MediaFilePathsJailTestCase(unittest.TestCase):
492 def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None:
493 """Passes a relative path through the jail check.
494
495 Args:
496 filepaths: The `MediaFilePaths` instance.
497 path: A path relative to the media store directory.
498
499 Raises:
500 ValueError: If the jail check fails.
501 """
502
503 @_wrap_with_jail_check(relative=True)
504 def _make_relative_path(self: MediaFilePaths, path: str) -> str:
505 return path
506
507 _make_relative_path(filepaths, path)
508
509 def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None:
510 """Passes an absolute path through the jail check.
511
512 Args:
513 filepaths: The `MediaFilePaths` instance.
514 path: A path relative to the media store directory.
515
516 Raises:
517 ValueError: If the jail check fails.
518 """
519
520 @_wrap_with_jail_check(relative=False)
521 def _make_absolute_path(self: MediaFilePaths, path: str) -> str:
522 return os.path.join(self.base_path, path)
523
524 _make_absolute_path(filepaths, path)
525
526 def test_traversal_inside(self) -> None:
527 """Test the jail check for paths that stay within the media directory."""
528 # Despite the `../`s, these paths still lie within the media directory and it's
529 # expected for the jail check to allow them through.
530 # These paths ought to trip the other checks in place and should never be
531 # returned.
532 filepaths = MediaFilePaths("/media_store")
533 path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar"
534 self._check_relative_path(filepaths, path)
535 self._check_absolute_path(filepaths, path)
536
537 def test_traversal_outside(self) -> None:
538 """Test that the jail check fails for paths that escape the media directory."""
539 filepaths = MediaFilePaths("/media_store")
540 path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar"
541 with self.assertRaises(ValueError):
542 self._check_relative_path(filepaths, path)
543 with self.assertRaises(ValueError):
544 self._check_absolute_path(filepaths, path)
545
546 def test_traversal_reentry(self) -> None:
547 """Test the jail check for paths that exit and re-enter the media directory."""
548 # These paths lie outside the media directory if it is a symlink, and inside
549 # otherwise. Ideally the check should fail, but this proves difficult.
550 # This test documents the behaviour for this edge case.
551 # These paths ought to trip the other checks in place and should never be
552 # returned.
553 filepaths = MediaFilePaths("/media_store")
554 path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar"
555 self._check_relative_path(filepaths, path)
556 self._check_absolute_path(filepaths, path)
557
558 def test_symlink(self) -> None:
559 """Test that a symlink does not cause the jail check to fail."""
560 media_store_path = self.mktemp()
561
562 # symlink the media store directory
563 os.symlink("/mnt/synapse/media_store", media_store_path)
564
565 # Test that relative and absolute paths don't trip the check
566 # NB: `media_store_path` is a relative path
567 filepaths = MediaFilePaths(media_store_path)
568 self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
569 self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
570
571 filepaths = MediaFilePaths(os.path.abspath(media_store_path))
572 self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
573 self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
574
575 def test_symlink_subdirectory(self) -> None:
576 """Test that a symlinked subdirectory does not cause the jail check to fail."""
577 media_store_path = self.mktemp()
578 os.mkdir(media_store_path)
579
580 # symlink `url_cache/`
581 os.symlink(
582 "/mnt/synapse/media_store_url_cache",
583 os.path.join(media_store_path, "url_cache"),
584 )
585
586 # Test that relative and absolute paths don't trip the check
587 # NB: `media_store_path` is a relative path
588 filepaths = MediaFilePaths(media_store_path)
589 self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
590 self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
591
592 filepaths = MediaFilePaths(os.path.abspath(media_store_path))
593 self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
594 self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
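The jail check exercised by the tests above guards against path traversal escaping the media store directory. As a rough standalone sketch of the underlying idea (my own simplified illustration with hypothetical names, not Synapse's `_wrap_with_jail_check`, which also has to cope with the symlink edge cases documented above):

```python
import os

def is_inside_base(base: str, relative: str) -> bool:
    # Join the candidate path onto the base directory and normalise
    # away any "..", then check the result still lies under the base.
    # Note: os.path.normpath does not resolve symlinks, which is why
    # the symlink re-entry case above is hard to reject.
    base = os.path.abspath(base)
    candidate = os.path.normpath(os.path.join(base, relative))
    return candidate == base or candidate.startswith(base + os.sep)

# "../"s that stay inside the base are allowed through...
print(is_inside_base("/media_store", "url_cache/2020-01-02/../../GerZNDnDZVjsOtar"))  # True
# ...while paths that escape the base are rejected.
print(is_inside_base("/media_store", "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar"))  # False
```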
1111 # See the License for the specific language governing permissions and
1212 # limitations under the License.
1313 import json
14
14 from contextlib import contextmanager
15 from typing import Generator
16
17 from twisted.enterprise.adbapi import ConnectionPool
18 from twisted.internet.defer import ensureDeferred
19 from twisted.test.proto_helpers import MemoryReactor
20
21 from synapse.api.room_versions import EventFormatVersions, RoomVersions
1522 from synapse.logging.context import LoggingContext
1623 from synapse.rest import admin
1724 from synapse.rest.client import login, room
18 from synapse.storage.databases.main.events_worker import EventsWorkerStore
25 from synapse.server import HomeServer
26 from synapse.storage.databases.main.events_worker import (
27 EVENT_QUEUE_THREADS,
28 EventsWorkerStore,
29 )
30 from synapse.storage.types import Connection
31 from synapse.util import Clock
1932 from synapse.util.async_helpers import yieldable_gather_results
2033
2134 from tests import unittest
143156
144157 # We should have fetched the event from the DB
145158 self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
159
160
161 class DatabaseOutageTestCase(unittest.HomeserverTestCase):
162 """Test event fetching during a database outage."""
163
164 def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
165 self.store: EventsWorkerStore = hs.get_datastore()
166
167 self.room_id = f"!room:{hs.hostname}"
168 self.event_ids = [f"event{i}" for i in range(20)]
169
170 self._populate_events()
171
172 def _populate_events(self) -> None:
173 """Ensure that there are test events in the database.
174
175 When testing with the in-memory SQLite database, all the events are lost during
176 the simulated outage.
177
178 To ensure consistency between `room_id`s and `event_id`s before and after the
179 outage, rows are built and inserted manually.
180
181 Upserts are used to handle the non-SQLite case where events are not lost.
182 """
183 self.get_success(
184 self.store.db_pool.simple_upsert(
185 "rooms",
186 {"room_id": self.room_id},
187 {"room_version": RoomVersions.V4.identifier},
188 )
189 )
190
191 self.event_ids = [f"event{i}" for i in range(20)]
192 for idx, event_id in enumerate(self.event_ids):
193 self.get_success(
194 self.store.db_pool.simple_upsert(
195 "events",
196 {"event_id": event_id},
197 {
198 "event_id": event_id,
199 "room_id": self.room_id,
200 "topological_ordering": idx,
201 "stream_ordering": idx,
202 "type": "test",
203 "processed": True,
204 "outlier": False,
205 },
206 )
207 )
208 self.get_success(
209 self.store.db_pool.simple_upsert(
210 "event_json",
211 {"event_id": event_id},
212 {
213 "room_id": self.room_id,
214 "json": json.dumps({"type": "test", "room_id": self.room_id}),
215 "internal_metadata": "{}",
216 "format_version": EventFormatVersions.V3,
217 },
218 )
219 )
220
221 @contextmanager
222 def _outage(self) -> Generator[None, None, None]:
223 """Simulate a database outage.
224
225 Returns:
226 A context manager. While the context is active, any attempts to connect to
227 the database will fail.
228 """
229 connection_pool = self.store.db_pool._db_pool
230
231 # Close all connections and shut down the database `ThreadPool`.
232 connection_pool.close()
233
234 # Restart the database `ThreadPool`.
235 connection_pool.start()
236
237 original_connection_factory = connection_pool.connectionFactory
238
239 def connection_factory(_pool: ConnectionPool) -> Connection:
240 raise Exception("Could not connect to the database.")
241
242 connection_pool.connectionFactory = connection_factory # type: ignore[assignment]
243 try:
244 yield
245 finally:
246 connection_pool.connectionFactory = original_connection_factory
247
248 # If the in-memory SQLite database is being used, all the events are gone.
249 # Restore the test data.
250 self._populate_events()
251
252 def test_failure(self) -> None:
253 """Test that event fetches do not get stuck during a database outage."""
254 with self._outage():
255 failure = self.get_failure(
256 self.store.get_event(self.event_ids[0]), Exception
257 )
258 self.assertEqual(str(failure.value), "Could not connect to the database.")
259
260 def test_recovery(self) -> None:
261 """Test that event fetchers recover after a database outage."""
262 with self._outage():
263 # Kick off a bunch of event fetches but do not pump the reactor
264 event_deferreds = []
265 for event_id in self.event_ids:
266 event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
267
268 # We should have maxed out on event fetcher threads
269 self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
270
271 # All the event fetchers will fail
272 self.pump()
273 self.assertEqual(self.store._event_fetch_ongoing, 0)
274
275 for event_deferred in event_deferreds:
276 failure = self.get_failure(event_deferred, Exception)
277 self.assertEqual(
278 str(failure.value), "Could not connect to the database."
279 )
280
281 # This next event fetch should succeed
282 self.get_success(self.store.get_event(self.event_ids[0]))
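The `_outage` context manager above works by swapping the pool's `connectionFactory` for one that raises, then restoring it on exit. The same patch-and-restore pattern in a minimal self-contained form (illustrative stand-in classes, not Twisted's `ConnectionPool`):

```python
from contextlib import contextmanager
from typing import Generator

class FakePool:
    """Stand-in for a connection pool with a patchable factory."""

    def connectionFactory(self) -> str:
        return "connection"

    def connect(self) -> str:
        return self.connectionFactory()

@contextmanager
def simulated_outage(pool: FakePool) -> Generator[None, None, None]:
    # Swap the factory so every new connection attempt raises, and
    # restore the original in `finally` so a failing test cannot
    # leave the pool permanently broken.
    original = pool.connectionFactory

    def failing_factory() -> str:
        raise ConnectionError("Could not connect to the database.")

    pool.connectionFactory = failing_factory  # type: ignore[assignment]
    try:
        yield
    finally:
        pool.connectionFactory = original  # type: ignore[assignment]

pool = FakePool()
with simulated_outage(pool):
    try:
        pool.connect()
    except ConnectionError as exc:
        print(exc)  # Could not connect to the database.
print(pool.connect())  # connection
```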
1313 import json
1414 import os
1515 import tempfile
16 from typing import List, Optional, cast
1617 from unittest.mock import Mock
1718
1819 import yaml
1920
2021 from twisted.internet import defer
22 from twisted.test.proto_helpers import MemoryReactor
2123
2224 from synapse.appservice import ApplicationService, ApplicationServiceState
2325 from synapse.config._base import ConfigError
26 from synapse.events import EventBase
27 from synapse.server import HomeServer
2428 from synapse.storage.database import DatabasePool, make_conn
2529 from synapse.storage.databases.main.appservice import (
2630 ApplicationServiceStore,
2731 ApplicationServiceTransactionStore,
2832 )
33 from synapse.util import Clock
2934
3035 from tests import unittest
3136 from tests.test_utils import make_awaitable
32 from tests.utils import setup_test_homeserver
33
34
35 class ApplicationServiceStoreTestCase(unittest.TestCase):
36 @defer.inlineCallbacks
37
38
39 class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
3740 def setUp(self):
38 self.as_yaml_files = []
39 hs = yield setup_test_homeserver(
40 self.addCleanup, federation_sender=Mock(), federation_client=Mock()
41 )
42
43 hs.config.appservice.app_service_config_files = self.as_yaml_files
44 hs.config.caches.event_cache_size = 1
41 super().setUp()
42
43 self.as_yaml_files: List[str] = []
44
45 self.hs.config.appservice.app_service_config_files = self.as_yaml_files
46 self.hs.config.caches.event_cache_size = 1
4547
4648 self.as_token = "token1"
4749 self.as_url = "some_url"
5254 self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
5355 self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
5456 # must be done after inserts
55 database = hs.get_datastores().databases[0]
57 database = self.hs.get_datastores().databases[0]
5658 self.store = ApplicationServiceStore(
57 database, make_conn(database._database_config, database.engine, "test"), hs
58 )
59
60 def tearDown(self):
59 database,
60 make_conn(database._database_config, database.engine, "test"),
61 self.hs,
62 )
63
64 def tearDown(self) -> None:
6165 # TODO: suboptimal that we need to create files for tests!
6266 for f in self.as_yaml_files:
6367 try:
6569 except Exception:
6670 pass
6771
68 def _add_appservice(self, as_token, id, url, hs_token, sender):
72 super().tearDown()
73
74 def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
6975 as_yaml = {
7076 "url": url,
7177 "as_token": as_token,
7985 outfile.write(yaml.dump(as_yaml))
8086 self.as_yaml_files.append(as_token)
8187
82 def test_retrieve_unknown_service_token(self):
88 def test_retrieve_unknown_service_token(self) -> None:
8389 service = self.store.get_app_service_by_token("invalid_token")
8490 self.assertEquals(service, None)
8591
86 def test_retrieval_of_service(self):
92 def test_retrieval_of_service(self) -> None:
8793 stored_service = self.store.get_app_service_by_token(self.as_token)
94 assert stored_service is not None
8895 self.assertEquals(stored_service.token, self.as_token)
8996 self.assertEquals(stored_service.id, self.as_id)
9097 self.assertEquals(stored_service.url, self.as_url)
9299 self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
93100 self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], [])
94101
95 def test_retrieval_of_all_services(self):
102 def test_retrieval_of_all_services(self) -> None:
96103 services = self.store.get_app_services()
97104 self.assertEquals(len(services), 3)
98105
99106
100 class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
101 @defer.inlineCallbacks
102 def setUp(self):
103 self.as_yaml_files = []
104
105 hs = yield setup_test_homeserver(
106 self.addCleanup, federation_sender=Mock(), federation_client=Mock()
107 )
108
109 hs.config.appservice.app_service_config_files = self.as_yaml_files
110 hs.config.caches.event_cache_size = 1
107 class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
108 def setUp(self) -> None:
109 super().setUp()
110 self.as_yaml_files: List[str] = []
111
112 self.hs.config.appservice.app_service_config_files = self.as_yaml_files
113 self.hs.config.caches.event_cache_size = 1
111114
112115 self.as_list = [
113116 {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
116119 {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"},
117120 ]
118121 for s in self.as_list:
119 yield self._add_service(s["url"], s["token"], s["id"])
122 self._add_service(s["url"], s["token"], s["id"])
120123
121124 self.as_yaml_files = []
122125
123126 # We assume there is only one database in these tests
124 database = hs.get_datastores().databases[0]
127 database = self.hs.get_datastores().databases[0]
125128 self.db_pool = database._db_pool
126129 self.engine = database.engine
127130
128 db_config = hs.config.database.get_single_database()
131 db_config = self.hs.config.database.get_single_database()
129132 self.store = TestTransactionStore(
130 database, make_conn(db_config, self.engine, "test"), hs
131 )
132
133 def _add_service(self, url, as_token, id):
133 database, make_conn(db_config, self.engine, "test"), self.hs
134 )
135
136 def _add_service(self, url, as_token, id) -> None:
134137 as_yaml = {
135138 "url": url,
136139 "as_token": as_token,
144147 outfile.write(yaml.dump(as_yaml))
145148 self.as_yaml_files.append(as_token)
146149
147 def _set_state(self, id, state, txn=None):
150 def _set_state(
151 self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
152 ):
148153 return self.db_pool.runOperation(
149154 self.engine.convert_param_style(
150155 "INSERT INTO application_services_state(as_id, state, last_txn) "
151156 "VALUES(?,?,?)"
152157 ),
153 (id, state, txn),
158 (id, state.value, txn),
154159 )
155160
156161 def _insert_txn(self, as_id, txn_id, events):
168173 "INSERT INTO application_services_state(as_id, last_txn, state) "
169174 "VALUES(?,?,?)"
170175 ),
171 (as_id, txn_id, ApplicationServiceState.UP),
172 )
173
174 @defer.inlineCallbacks
175 def test_get_appservice_state_none(self):
176 (as_id, txn_id, ApplicationServiceState.UP.value),
177 )
178
179 def test_get_appservice_state_none(
180 self,
181 ) -> None:
176182 service = Mock(id="999")
177 state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
183 state = self.get_success(self.store.get_appservice_state(service))
178184 self.assertEquals(None, state)
179185
180 @defer.inlineCallbacks
181 def test_get_appservice_state_up(self):
182 yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
183 service = Mock(id=self.as_list[0]["id"])
184 state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
186 def test_get_appservice_state_up(
187 self,
188 ) -> None:
189 self.get_success(
190 self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
191 )
192 service = Mock(id=self.as_list[0]["id"])
193 state = self.get_success(
194 self.store.get_appservice_state(service)
195 )
185196 self.assertEquals(ApplicationServiceState.UP, state)
186197
187 @defer.inlineCallbacks
188 def test_get_appservice_state_down(self):
189 yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
190 yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
191 yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
198 def test_get_appservice_state_down(
199 self,
200 ) -> None:
201 self.get_success(
202 self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
203 )
204 self.get_success(
205 self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
206 )
207 self.get_success(
208 self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
209 )
192210 service = Mock(id=self.as_list[1]["id"])
193 state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
211 state = self.get_success(self.store.get_appservice_state(service))
194212 self.assertEquals(ApplicationServiceState.DOWN, state)
195213
196 @defer.inlineCallbacks
197 def test_get_appservices_by_state_none(self):
198 services = yield defer.ensureDeferred(
214 def test_get_appservices_by_state_none(
215 self,
216 ) -> None:
217 services = self.get_success(
199218 self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
200219 )
201220 self.assertEquals(0, len(services))
202221
203 @defer.inlineCallbacks
204 def test_set_appservices_state_down(self):
222 def test_set_appservices_state_down(
223 self,
224 ) -> None:
205225 service = Mock(id=self.as_list[1]["id"])
206 yield defer.ensureDeferred(
226 self.get_success(
207227 self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
208228 )
209 rows = yield self.db_pool.runQuery(
210 self.engine.convert_param_style(
211 "SELECT as_id FROM application_services_state WHERE state=?"
212 ),
213 (ApplicationServiceState.DOWN,),
229 rows = self.get_success(
230 self.db_pool.runQuery(
231 self.engine.convert_param_style(
232 "SELECT as_id FROM application_services_state WHERE state=?"
233 ),
234 (ApplicationServiceState.DOWN.value,),
235 )
214236 )
215237 self.assertEquals(service.id, rows[0][0])
216238
217 @defer.inlineCallbacks
218 def test_set_appservices_state_multiple_up(self):
239 def test_set_appservices_state_multiple_up(
240 self,
241 ) -> None:
219242 service = Mock(id=self.as_list[1]["id"])
220 yield defer.ensureDeferred(
243 self.get_success(
221244 self.store.set_appservice_state(service, ApplicationServiceState.UP)
222245 )
223 yield defer.ensureDeferred(
246 self.get_success(
224247 self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
225248 )
226 yield defer.ensureDeferred(
249 self.get_success(
227250 self.store.set_appservice_state(service, ApplicationServiceState.UP)
228251 )
229 rows = yield self.db_pool.runQuery(
230 self.engine.convert_param_style(
231 "SELECT as_id FROM application_services_state WHERE state=?"
232 ),
233 (ApplicationServiceState.UP,),
252 rows = self.get_success(
253 self.db_pool.runQuery(
254 self.engine.convert_param_style(
255 "SELECT as_id FROM application_services_state WHERE state=?"
256 ),
257 (ApplicationServiceState.UP.value,),
258 )
234259 )
235260 self.assertEquals(service.id, rows[0][0])
236261
237 @defer.inlineCallbacks
238 def test_create_appservice_txn_first(self):
239 service = Mock(id=self.as_list[0]["id"])
240 events = [Mock(event_id="e1"), Mock(event_id="e2")]
241 txn = yield defer.ensureDeferred(
242 self.store.create_appservice_txn(service, events, [])
262 def test_create_appservice_txn_first(
263 self,
264 ) -> None:
265 service = Mock(id=self.as_list[0]["id"])
266 events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
267 txn = self.get_success(
268 self.store.create_appservice_txn(service, events, [])
243269 )
244270 self.assertEquals(txn.id, 1)
245271 self.assertEquals(txn.events, events)
246272 self.assertEquals(txn.service, service)
247273
248 @defer.inlineCallbacks
249 def test_create_appservice_txn_older_last_txn(self):
250 service = Mock(id=self.as_list[0]["id"])
251 events = [Mock(event_id="e1"), Mock(event_id="e2")]
252 yield self._set_last_txn(service.id, 9643) # AS is falling behind
253 yield self._insert_txn(service.id, 9644, events)
254 yield self._insert_txn(service.id, 9645, events)
255 txn = yield defer.ensureDeferred(
256 self.store.create_appservice_txn(service, events, [])
257 )
274 def test_create_appservice_txn_older_last_txn(
275 self,
276 ) -> None:
277 service = Mock(id=self.as_list[0]["id"])
278 events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
279 self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
280 self.get_success(self._insert_txn(service.id, 9644, events))
281 self.get_success(self._insert_txn(service.id, 9645, events))
282 txn = self.get_success(self.store.create_appservice_txn(service, events, []))
258283 self.assertEquals(txn.id, 9646)
259284 self.assertEquals(txn.events, events)
260285 self.assertEquals(txn.service, service)
261286
262 @defer.inlineCallbacks
263 def test_create_appservice_txn_up_to_date_last_txn(self):
264 service = Mock(id=self.as_list[0]["id"])
265 events = [Mock(event_id="e1"), Mock(event_id="e2")]
266 yield self._set_last_txn(service.id, 9643)
267 txn = yield defer.ensureDeferred(
268 self.store.create_appservice_txn(service, events, [])
269 )
287 def test_create_appservice_txn_up_to_date_last_txn(
288 self,
289 ) -> None:
290 service = Mock(id=self.as_list[0]["id"])
291 events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
292 self.get_success(self._set_last_txn(service.id, 9643))
293 txn = self.get_success(self.store.create_appservice_txn(service, events, []))
270294 self.assertEquals(txn.id, 9644)
271295 self.assertEquals(txn.events, events)
272296 self.assertEquals(txn.service, service)
273297
274 @defer.inlineCallbacks
275 def test_create_appservice_txn_up_fuzzing(self):
276 service = Mock(id=self.as_list[0]["id"])
277 events = [Mock(event_id="e1"), Mock(event_id="e2")]
278 yield self._set_last_txn(service.id, 9643)
298 def test_create_appservice_txn_up_fuzzing(
299 self,
300 ) -> None:
301 service = Mock(id=self.as_list[0]["id"])
302 events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
303 self.get_success(self._set_last_txn(service.id, 9643))
279304
280305 # dump in rows with higher IDs to make sure the queries aren't wrong.
281 yield self._set_last_txn(self.as_list[1]["id"], 119643)
282 yield self._set_last_txn(self.as_list[2]["id"], 9)
283 yield self._set_last_txn(self.as_list[3]["id"], 9643)
284 yield self._insert_txn(self.as_list[1]["id"], 119644, events)
285 yield self._insert_txn(self.as_list[1]["id"], 119645, events)
286 yield self._insert_txn(self.as_list[1]["id"], 119646, events)
287 yield self._insert_txn(self.as_list[2]["id"], 10, events)
288 yield self._insert_txn(self.as_list[3]["id"], 9643, events)
289
290 txn = yield defer.ensureDeferred(
291 self.store.create_appservice_txn(service, events, [])
292 )
306 self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643))
307 self.get_success(self._set_last_txn(self.as_list[2]["id"], 9))
308 self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643))
309 self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events))
310 self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events))
311 self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events))
312 self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
313 self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
314
315 txn = self.get_success(self.store.create_appservice_txn(service, events, []))
293316 self.assertEquals(txn.id, 9644)
294317 self.assertEquals(txn.events, events)
295318 self.assertEquals(txn.service, service)
296319
297 @defer.inlineCallbacks
298 def test_complete_appservice_txn_first_txn(self):
320 def test_complete_appservice_txn_first_txn(
321 self,
322 ) -> None:
299323 service = Mock(id=self.as_list[0]["id"])
300324 events = [Mock(event_id="e1"), Mock(event_id="e2")]
301325 txn_id = 1
302326
303 yield self._insert_txn(service.id, txn_id, events)
304 yield defer.ensureDeferred(
327 self.get_success(self._insert_txn(service.id, txn_id, events))
328 self.get_success(
305329 self.store.complete_appservice_txn(txn_id=txn_id, service=service)
306330 )
307331
308 res = yield self.db_pool.runQuery(
309 self.engine.convert_param_style(
310 "SELECT last_txn FROM application_services_state WHERE as_id=?"
311 ),
312 (service.id,),
332 res = self.get_success(
333 self.db_pool.runQuery(
334 self.engine.convert_param_style(
335 "SELECT last_txn FROM application_services_state WHERE as_id=?"
336 ),
337 (service.id,),
338 )
313339 )
314340 self.assertEquals(1, len(res))
315341 self.assertEquals(txn_id, res[0][0])
316342
317 res = yield self.db_pool.runQuery(
318 self.engine.convert_param_style(
319 "SELECT * FROM application_services_txns WHERE txn_id=?"
320 ),
321 (txn_id,),
343 res = self.get_success(
344 self.db_pool.runQuery(
345 self.engine.convert_param_style(
346 "SELECT * FROM application_services_txns WHERE txn_id=?"
347 ),
348 (txn_id,),
349 )
322350 )
323351 self.assertEquals(0, len(res))
324352
325 @defer.inlineCallbacks
326 def test_complete_appservice_txn_existing_in_state_table(self):
353 def test_complete_appservice_txn_existing_in_state_table(
354 self,
355 ) -> None:
327356 service = Mock(id=self.as_list[0]["id"])
328357 events = [Mock(event_id="e1"), Mock(event_id="e2")]
329358 txn_id = 5
330 yield self._set_last_txn(service.id, 4)
331 yield self._insert_txn(service.id, txn_id, events)
332 yield defer.ensureDeferred(
359 self.get_success(self._set_last_txn(service.id, 4))
360 self.get_success(self._insert_txn(service.id, txn_id, events))
361 self.get_success(
333362 self.store.complete_appservice_txn(txn_id=txn_id, service=service)
334363 )
335364
336 res = yield self.db_pool.runQuery(
337 self.engine.convert_param_style(
338 "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
339 ),
340 (service.id,),
365 res = self.get_success(
366 self.db_pool.runQuery(
367 self.engine.convert_param_style(
368 "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
369 ),
370 (service.id,),
371 )
341372 )
342373 self.assertEquals(1, len(res))
343374 self.assertEquals(txn_id, res[0][0])
344 self.assertEquals(ApplicationServiceState.UP, res[0][1])
345
346 res = yield self.db_pool.runQuery(
347 self.engine.convert_param_style(
348 "SELECT * FROM application_services_txns WHERE txn_id=?"
349 ),
350 (txn_id,),
375 self.assertEquals(ApplicationServiceState.UP.value, res[0][1])
376
377 res = self.get_success(
378 self.db_pool.runQuery(
379 self.engine.convert_param_style(
380 "SELECT * FROM application_services_txns WHERE txn_id=?"
381 ),
382 (txn_id,),
383 )
351384 )
352385 self.assertEquals(0, len(res))
353386
354 @defer.inlineCallbacks
355 def test_get_oldest_unsent_txn_none(self):
356 service = Mock(id=self.as_list[0]["id"])
357
358 txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
387 def test_get_oldest_unsent_txn_none(
388 self,
389 ) -> None:
390 service = Mock(id=self.as_list[0]["id"])
391
392 txn = self.get_success(self.store.get_oldest_unsent_txn(service))
359393 self.assertEquals(None, txn)
360394
361 @defer.inlineCallbacks
362 def test_get_oldest_unsent_txn(self):
395 def test_get_oldest_unsent_txn(self) -> None:
363396 service = Mock(id=self.as_list[0]["id"])
364397 events = [Mock(event_id="e1"), Mock(event_id="e2")]
365398 other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
366399
367400 # we aren't testing store._base stuff here, so mock this out
368 self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
369
370 yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
371 yield self._insert_txn(service.id, 10, events)
372 yield self._insert_txn(service.id, 11, other_events)
373 yield self._insert_txn(service.id, 12, other_events)
374
375 txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
401 # (ignore needed because Mypy won't allow us to assign to a method otherwise)
402 self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment]
403
404 self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
405 self.get_success(self._insert_txn(service.id, 10, events))
406 self.get_success(self._insert_txn(service.id, 11, other_events))
407 self.get_success(self._insert_txn(service.id, 12, other_events))
408
409 txn = self.get_success(self.store.get_oldest_unsent_txn(service))
376410 self.assertEquals(service, txn.service)
377411 self.assertEquals(10, txn.id)
378412 self.assertEquals(events, txn.events)
379413
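`make_awaitable` (imported from `tests.test_utils`) lets a plain `Mock` stand in for an `async def` method by wrapping a value in an awaitable. A minimal stand-in sketch of the pattern:

```python
import asyncio
from unittest.mock import Mock

async def make_awaitable(value):
    # Minimal stand-in for tests.test_utils.make_awaitable: a coroutine
    # that resolves immediately to the given value (awaitable once).
    return value

store = Mock()
# to callers, the mocked method now behaves like an async def
store.get_events_as_list = Mock(return_value=make_awaitable(["e1", "e2"]))

async def caller():
    return await store.get_events_as_list()

assert asyncio.run(caller()) == ["e1", "e2"]
```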
380 @defer.inlineCallbacks
381 def test_get_appservices_by_state_single(self):
382 yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
383 yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
384
385 services = yield defer.ensureDeferred(
414 def test_get_appservices_by_state_single(
415 self,
416 ) -> None:
417 self.get_success(
418 self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
419 )
420 self.get_success(
421 self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
422 )
423
424 services = self.get_success(
386425 self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
387426 )
388427 self.assertEquals(1, len(services))
389428 self.assertEquals(self.as_list[0]["id"], services[0].id)
390429
391 @defer.inlineCallbacks
392 def test_get_appservices_by_state_multiple(self):
393 yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
394 yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
395 yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
396 yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
397
398 services = yield defer.ensureDeferred(
430 def test_get_appservices_by_state_multiple(
431 self,
432 ) -> None:
433 self.get_success(
434 self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
435 )
436 self.get_success(
437 self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
438 )
439 self.get_success(
440 self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
441 )
442 self.get_success(
443 self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
444 )
445
446 services = self.get_success(
399447 self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
400448 )
401449 self.assertEquals(2, len(services))
406454
407455
408456 class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
409 def make_homeserver(self, reactor, clock):
410 hs = self.setup_test_homeserver()
411 return hs
412
413 def prepare(self, hs, reactor, clock):
457 def prepare(
458 self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
459 ) -> None:
414460 self.service = Mock(id="foo")
415461 self.store = self.hs.get_datastore()
416 self.get_success(self.store.set_appservice_state(self.service, "up"))
417
418 def test_get_type_stream_id_for_appservice_no_value(self):
462 self.get_success(
463 self.store.set_appservice_state(self.service, ApplicationServiceState.UP)
464 )
465
466 def test_get_type_stream_id_for_appservice_no_value(self) -> None:
419467 value = self.get_success(
420468 self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
421469 )
426474 )
427475 self.assertEquals(value, 0)
428476
429 def test_get_type_stream_id_for_appservice_invalid_type(self):
477 def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
430478 self.get_failure(
431479 self.store.get_type_stream_id_for_appservice(self.service, "foobar"),
432480 ValueError,
433481 )
434482
435 def test_set_type_stream_id_for_appservice(self):
483 def test_set_type_stream_id_for_appservice(self) -> None:
436484 read_receipt_value = 1024
437485 self.get_success(
438486 self.store.set_type_stream_id_for_appservice(
454502 )
455503 self.assertEqual(result, read_receipt_value)
456504
457 def test_set_type_stream_id_for_appservice_invalid_type(self):
505 def test_set_type_stream_id_for_appservice_invalid_type(self) -> None:
458506 self.get_failure(
459507 self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
460508 ValueError,
463511
464512 # required for ApplicationServiceTransactionStoreTestCase tests
465513 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
466 def __init__(self, database: DatabasePool, db_conn, hs):
514 def __init__(self, database: DatabasePool, db_conn, hs) -> None:
467515 super().__init__(database, db_conn, hs)
468516
469517
470 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
471 def _write_config(self, suffix, **kwargs):
518 class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
519 def _write_config(self, suffix, **kwargs) -> str:
472520 vals = {
473521 "id": "id" + suffix,
474522 "url": "url" + suffix,
484532 f.write(yaml.dump(vals))
485533 return path
486534
487 @defer.inlineCallbacks
488 def test_unique_works(self):
535 def test_unique_works(self) -> None:
489536 f1 = self._write_config(suffix="1")
490537 f2 = self._write_config(suffix="2")
491538
492 hs = yield setup_test_homeserver(
493 self.addCleanup, federation_sender=Mock(), federation_client=Mock()
494 )
495
496 hs.config.appservice.app_service_config_files = [f1, f2]
497 hs.config.caches.event_cache_size = 1
498
499 database = hs.get_datastores().databases[0]
539 self.hs.config.appservice.app_service_config_files = [f1, f2]
540 self.hs.config.caches.event_cache_size = 1
541
542 database = self.hs.get_datastores().databases[0]
500543 ApplicationServiceStore(
501 database, make_conn(database._database_config, database.engine, "test"), hs
502 )
503
504 @defer.inlineCallbacks
505 def test_duplicate_ids(self):
544 database,
545 make_conn(database._database_config, database.engine, "test"),
546 self.hs,
547 )
548
549 def test_duplicate_ids(self) -> None:
506550 f1 = self._write_config(id="id", suffix="1")
507551 f2 = self._write_config(id="id", suffix="2")
508552
509 hs = yield setup_test_homeserver(
510 self.addCleanup, federation_sender=Mock(), federation_client=Mock()
511 )
512
513 hs.config.appservice.app_service_config_files = [f1, f2]
514 hs.config.caches.event_cache_size = 1
553 self.hs.config.appservice.app_service_config_files = [f1, f2]
554 self.hs.config.caches.event_cache_size = 1
515555
516556 with self.assertRaises(ConfigError) as cm:
517 database = hs.get_datastores().databases[0]
557 database = self.hs.get_datastores().databases[0]
518558 ApplicationServiceStore(
519559 database,
520560 make_conn(database._database_config, database.engine, "test"),
521 hs,
561 self.hs,
522562 )
523563
524564 e = cm.exception
526566 self.assertIn(f2, str(e))
527567 self.assertIn("id", str(e))
528568
529 @defer.inlineCallbacks
530 def test_duplicate_as_tokens(self):
569 def test_duplicate_as_tokens(self) -> None:
531570 f1 = self._write_config(as_token="as_token", suffix="1")
532571 f2 = self._write_config(as_token="as_token", suffix="2")
533572
534 hs = yield setup_test_homeserver(
535 self.addCleanup, federation_sender=Mock(), federation_client=Mock()
536 )
537
538 hs.config.appservice.app_service_config_files = [f1, f2]
539 hs.config.caches.event_cache_size = 1
573 self.hs.config.appservice.app_service_config_files = [f1, f2]
574 self.hs.config.caches.event_cache_size = 1
540575
541576 with self.assertRaises(ConfigError) as cm:
542 database = hs.get_datastores().databases[0]
577 database = self.hs.get_datastores().databases[0]
543578 ApplicationServiceStore(
544579 database,
545580 make_conn(database._database_config, database.engine, "test"),
546 hs,
581 self.hs,
547582 )
548583
549584 e = cm.exception
0 from unittest.mock import Mock
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 # Use backported mock for AsyncMock support on Python 3.6.
15 from mock import Mock
16
17 from twisted.internet.defer import Deferred, ensureDeferred
118
219 from synapse.storage.background_updates import BackgroundUpdater
320
421 from tests import unittest
22 from tests.test_utils import make_awaitable
523
624
725 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
1937
2038 def test_do_background_update(self):
2139 # the time we claim it takes to update one item when running the update
22 duration_ms = 4200
40 duration_ms = 10
2341
2442 # the target runtime for each bg update
25 target_background_update_duration_ms = 5000000
43 target_background_update_duration_ms = 100
2644
2745 store = self.hs.get_datastore()
2846 self.get_success(
4765 self.update_handler.side_effect = update
4866 self.update_handler.reset_mock()
4967 res = self.get_success(
50 self.updates.do_next_background_update(
51 target_background_update_duration_ms
52 ),
53 by=0.1,
68 self.updates.do_next_background_update(False),
69 by=0.01,
5470 )
5571 self.assertFalse(res)
5672
7389
7490 self.update_handler.side_effect = update
7591 self.update_handler.reset_mock()
76 result = self.get_success(
77 self.updates.do_next_background_update(target_background_update_duration_ms)
78 )
92 result = self.get_success(self.updates.do_next_background_update(False))
7993 self.assertFalse(result)
8094 self.update_handler.assert_called_once()
8195
8296 # third step: we don't expect to be called any more
8397 self.update_handler.reset_mock()
84 result = self.get_success(
85 self.updates.do_next_background_update(target_background_update_duration_ms)
86 )
98 result = self.get_success(self.updates.do_next_background_update(False))
8799 self.assertTrue(result)
88100 self.assertFalse(self.update_handler.called)
101
102
103 class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
104 def prepare(self, reactor, clock, homeserver):
105 self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
106 # the base test class should have run the real bg updates for us
107 self.assertTrue(
108 self.get_success(self.updates.has_completed_background_updates())
109 )
110
111 self.update_deferred = Deferred()
112 self.update_handler = Mock(return_value=self.update_deferred)
113 self.updates.register_background_update_handler(
114 "test_update", self.update_handler
115 )
116
117 # Mock out the AsyncContextManager
118 self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
119 self._update_ctx_manager.__aenter__ = Mock(
120 return_value=make_awaitable(None),
121 )
122 self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
123
124 # Mock out the `update_handler` callback
125 self._on_update = Mock(return_value=self._update_ctx_manager)
126
127 # Define a default batch size value that's not the same as the internal default
128 # value (100).
129 self._default_batch_size = 500
130
131 # Register the callbacks with more mocks
132 self.hs.get_module_api().register_background_update_controller_callbacks(
133 on_update=self._on_update,
134 min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
135 default_batch_size=Mock(
136 return_value=make_awaitable(self._default_batch_size),
137 ),
138 )
139
140 def test_controller(self):
141 store = self.hs.get_datastore()
142 self.get_success(
143 store.db_pool.simple_insert(
144 "background_updates",
145 values={"update_name": "test_update", "progress_json": "{}"},
146 )
147 )
148
149 # Set the return value for the context manager.
150 enter_defer = Deferred()
151 self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
152
153 # Start the background update.
154 do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
155
156 self.pump()
157
158 # `run_update` should have been called, but the update handler won't be
159 # called until the `enter_defer` (returned by `__aenter__`) is resolved.
160 self._on_update.assert_called_once_with(
161 "test_update",
162 "master",
163 False,
164 )
165 self.assertFalse(do_update_d.called)
166 self.assertFalse(self.update_deferred.called)
167
168 # Resolving the `enter_defer` should call the update handler, which then
169 # blocks.
170 enter_defer.callback(100)
171 self.pump()
172 self.update_handler.assert_called_once_with({}, self._default_batch_size)
173 self.assertFalse(self.update_deferred.called)
174 self._update_ctx_manager.__aexit__.assert_not_called()
175
176 # Resolving the update handler deferred should cause the
177 # `do_next_background_update` to finish and return
178 self.update_deferred.callback(100)
179 self.pump()
180 self._update_ctx_manager.__aexit__.assert_called()
181 self.get_success(do_update_d)
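The controller test above stubs an async context manager by hand-assigning `__aenter__`/`__aexit__` mocks. On Python 3.8+ a `MagicMock` already exposes those as `AsyncMock`s, so the trick can be sketched without any Synapse machinery (names here are illustrative only):

```python
import asyncio
from unittest.mock import MagicMock

ctx = MagicMock()
# __aenter__/__aexit__ are AsyncMocks on MagicMock (Python 3.8+); the value
# yielded by `async with` plays the role of the batch size in the test above.
ctx.__aenter__.return_value = 100

async def run_update(ctx_manager):
    async with ctx_manager as batch_size:
        return batch_size

assert asyncio.run(run_update(ctx)) == 100
ctx.__aexit__.assert_called_once()
```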
663663 ):
664664 iterations += 1
665665 self.get_success(
666 self.store.db_pool.updates.do_next_background_update(100), by=0.1
666 self.store.db_pool.updates.do_next_background_update(False), by=0.1
667667 )
668668
669669 # Ensure that we did actually take multiple iterations to process the
722722 ):
723723 iterations += 1
724724 self.get_success(
725 self.store.db_pool.updates.do_next_background_update(100), by=0.1
725 self.store.db_pool.updates.do_next_background_update(False), by=0.1
726726 )
727727
728728 # Ensure that we did actually take multiple iterations to process the
1212 # limitations under the License.
1313
1414
15 from twisted.internet import defer
16
1715 from synapse.types import UserID
1816
1917 from tests import unittest
20 from tests.utils import setup_test_homeserver
2118
2219
23 class DataStoreTestCase(unittest.TestCase):
24 @defer.inlineCallbacks
25 def setUp(self):
26 hs = yield setup_test_homeserver(self.addCleanup)
20 class DataStoreTestCase(unittest.HomeserverTestCase):
21 def setUp(self) -> None:
22 super(DataStoreTestCase, self).setUp()
2723
28 self.store = hs.get_datastore()
24 self.store = self.hs.get_datastore()
2925
3026 self.user = UserID.from_string("@abcde:test")
3127 self.displayname = "Frank"
3228
33 @defer.inlineCallbacks
34 def test_get_users_paginate(self):
35 yield defer.ensureDeferred(
36 self.store.register_user(self.user.to_string(), "pass")
37 )
38 yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
39 yield defer.ensureDeferred(
29 def test_get_users_paginate(self) -> None:
30 self.get_success(self.store.register_user(self.user.to_string(), "pass"))
31 self.get_success(self.store.create_profile(self.user.localpart))
32 self.get_success(
4033 self.store.set_profile_displayname(self.user.localpart, self.displayname)
4134 )
4235
43 users, total = yield defer.ensureDeferred(
36 users, total = self.get_success(
4437 self.store.get_users_paginate(0, 10, name="bc", guests=False)
4538 )
4639
4740 self.assertEquals(1, total)
4841 self.assertEquals(self.displayname, users.pop()["displayname"])
4942
50 users, total = yield defer.ensureDeferred(
43 users, total = self.get_success(
5144 self.store.get_users_paginate(0, 10, name="BC", guests=False)
5245 )
5346
2222 from synapse.rest.client import login, register, room
2323 from synapse.server import HomeServer
2424 from synapse.storage import DataStore
25 from synapse.storage.background_updates import _BackgroundUpdateHandler
2526 from synapse.storage.roommember import ProfileInfo
2627 from synapse.util import Clock
2728
390391
391392 with mock.patch.dict(
392393 self.store.db_pool.updates._background_update_handlers,
393 populate_user_directory_process_users=mocked_process_users,
394 populate_user_directory_process_users=_BackgroundUpdateHandler(
395 mocked_process_users,
396 ),
394397 ):
395398 self._purge_and_rebuild_user_dir()
396399
1212 # limitations under the License.
1313 import logging
1414 from typing import Optional
15 from unittest.mock import Mock
16
17 from twisted.internet import defer
18 from twisted.internet.defer import succeed
1915
2016 from synapse.api.room_versions import RoomVersions
21 from synapse.events import FrozenEvent
17 from synapse.events import EventBase
18 from synapse.types import JsonDict
2219 from synapse.visibility import filter_events_for_server
2320
24 import tests.unittest
25 from tests.utils import create_room, setup_test_homeserver
21 from tests import unittest
22 from tests.utils import create_room
2623
2724 logger = logging.getLogger(__name__)
2825
2926 TEST_ROOM_ID = "!TEST:ROOM"
3027
3128
32 class FilterEventsForServerTestCase(tests.unittest.TestCase):
33 @defer.inlineCallbacks
34 def setUp(self):
35 self.hs = yield setup_test_homeserver(self.addCleanup)
29 class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
30 def setUp(self) -> None:
31 super(FilterEventsForServerTestCase, self).setUp()
3632 self.event_creation_handler = self.hs.get_event_creation_handler()
3733 self.event_builder_factory = self.hs.get_event_builder_factory()
3834 self.storage = self.hs.get_storage()
3935
40 yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
36 self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
4137
42 @defer.inlineCallbacks
43 def test_filtering(self):
38 def test_filtering(self) -> None:
4439 #
4540 # The events to be filtered consist of 10 membership events (it doesn't
4641 # really matter if they are joins or leaves, so let's make them joins).
5045 #
5146
5247 # before we do that, we persist some other events to act as state.
53 yield self.inject_visibility("@admin:hs", "joined")
48 self.get_success(self._inject_visibility("@admin:hs", "joined"))
5449 for i in range(0, 10):
55 yield self.inject_room_member("@resident%i:hs" % i)
50 self.get_success(self._inject_room_member("@resident%i:hs" % i))
5651
5752 events_to_filter = []
5853
5954 for i in range(0, 10):
6055 user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
61 evt = yield self.inject_room_member(user, extra_content={"a": "b"})
56 evt = self.get_success(
57 self._inject_room_member(user, extra_content={"a": "b"})
58 )
6259 events_to_filter.append(evt)
6360
64 filtered = yield defer.ensureDeferred(
61 filtered = self.get_success(
6562 filter_events_for_server(self.storage, "test_server", events_to_filter)
6663 )
6764
7471 self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
7572 self.assertEqual(filtered[i].content["a"], "b")
7673
77 @defer.inlineCallbacks
78 def test_erased_user(self):
74 def test_erased_user(self) -> None:
7975 # 4 message events, from erased and unerased users, with a membership
8076 # change in the middle of them.
8177 events_to_filter = []
8278
83 evt = yield self.inject_message("@unerased:local_hs")
79 evt = self.get_success(self._inject_message("@unerased:local_hs"))
8480 events_to_filter.append(evt)
8581
86 evt = yield self.inject_message("@erased:local_hs")
82 evt = self.get_success(self._inject_message("@erased:local_hs"))
8783 events_to_filter.append(evt)
8884
89 evt = yield self.inject_room_member("@joiner:remote_hs")
85 evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
9086 events_to_filter.append(evt)
9187
92 evt = yield self.inject_message("@unerased:local_hs")
88 evt = self.get_success(self._inject_message("@unerased:local_hs"))
9389 events_to_filter.append(evt)
9490
95 evt = yield self.inject_message("@erased:local_hs")
91 evt = self.get_success(self._inject_message("@erased:local_hs"))
9692 events_to_filter.append(evt)
9793
9894 # the erasey user gets erased
99 yield defer.ensureDeferred(
100 self.hs.get_datastore().mark_user_erased("@erased:local_hs")
101 )
95 self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs"))
10296
10397 # ... and the filtering happens.
104 filtered = yield defer.ensureDeferred(
98 filtered = self.get_success(
10599 filter_events_for_server(self.storage, "test_server", events_to_filter)
106100 )
107101
122116 for i in (1, 4):
123117 self.assertNotIn("body", filtered[i].content)
124118
125 @defer.inlineCallbacks
126 def inject_visibility(self, user_id, visibility):
119 def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
127120 content = {"history_visibility": visibility}
128121 builder = self.event_builder_factory.for_room_version(
129122 RoomVersions.V1,
136129 },
137130 )
138131
139 event, context = yield defer.ensureDeferred(
132 event, context = self.get_success(
140133 self.event_creation_handler.create_new_client_event(builder)
141134 )
142 yield defer.ensureDeferred(
143 self.storage.persistence.persist_event(event, context)
144 )
135 self.get_success(self.storage.persistence.persist_event(event, context))
145136 return event
146137
147 @defer.inlineCallbacks
148 def inject_room_member(
149 self, user_id, membership="join", extra_content: Optional[dict] = None
150 ):
138 def _inject_room_member(
139 self,
140 user_id: str,
141 membership: str = "join",
142 extra_content: Optional[JsonDict] = None,
143 ) -> EventBase:
151144 content = {"membership": membership}
152145 content.update(extra_content or {})
153146 builder = self.event_builder_factory.for_room_version(
161154 },
162155 )
163156
164 event, context = yield defer.ensureDeferred(
157 event, context = self.get_success(
165158 self.event_creation_handler.create_new_client_event(builder)
166159 )
167160
168 yield defer.ensureDeferred(
169 self.storage.persistence.persist_event(event, context)
170 )
161 self.get_success(self.storage.persistence.persist_event(event, context))
171162 return event
172163
173 @defer.inlineCallbacks
174 def inject_message(self, user_id, content=None):
164 def _inject_message(
165 self, user_id: str, content: Optional[JsonDict] = None
166 ) -> EventBase:
175167 if content is None:
176168 content = {"body": "testytest", "msgtype": "m.text"}
177169 builder = self.event_builder_factory.for_room_version(
184176 },
185177 )
186178
187 event, context = yield defer.ensureDeferred(
179 event, context = self.get_success(
188180 self.event_creation_handler.create_new_client_event(builder)
189181 )
190182
191 yield defer.ensureDeferred(
192 self.storage.persistence.persist_event(event, context)
193 )
183 self.get_success(self.storage.persistence.persist_event(event, context))
194184 return event
195
196 @defer.inlineCallbacks
197 def test_large_room(self):
198 # see what happens when we have a large room with hundreds of thousands
199 # of membership events
200
201 # As above, the events to be filtered consist of 10 membership events,
202 # where one of them is for a user on the server we are filtering for.
203
204 import cProfile
205 import pstats
206 import time
207
208 # we stub out the store, because building up all that state the normal
209 # way is very slow.
210 test_store = _TestStore()
211
212 # our initial state is 100000 membership events and one
213 # history_visibility event.
214 room_state = []
215
216 history_visibility_evt = FrozenEvent(
217 {
218 "event_id": "$history_vis",
219 "type": "m.room.history_visibility",
220 "sender": "@resident_user_0:test.com",
221 "state_key": "",
222 "room_id": TEST_ROOM_ID,
223 "content": {"history_visibility": "joined"},
224 }
225 )
226 room_state.append(history_visibility_evt)
227 test_store.add_event(history_visibility_evt)
228
229 for i in range(0, 100000):
230 user = "@resident_user_%i:test.com" % (i,)
231 evt = FrozenEvent(
232 {
233 "event_id": "$res_event_%i" % (i,),
234 "type": "m.room.member",
235 "state_key": user,
236 "sender": user,
237 "room_id": TEST_ROOM_ID,
238 "content": {"membership": "join", "extra": "zzz,"},
239 }
240 )
241 room_state.append(evt)
242 test_store.add_event(evt)
243
244 events_to_filter = []
245 for i in range(0, 10):
246 user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
247 evt = FrozenEvent(
248 {
249 "event_id": "$evt%i" % (i,),
250 "type": "m.room.member",
251 "state_key": user,
252 "sender": user,
253 "room_id": TEST_ROOM_ID,
254 "content": {"membership": "join", "extra": "zzz"},
255 }
256 )
257 events_to_filter.append(evt)
258 room_state.append(evt)
259
260 test_store.add_event(evt)
261 test_store.set_state_ids_for_event(
262 evt, {(e.type, e.state_key): e.event_id for e in room_state}
263 )
264
265 pr = cProfile.Profile()
266 pr.enable()
267
268 logger.info("Starting filtering")
269 start = time.time()
270
271 storage = Mock()
272 storage.main = test_store
273 storage.state = test_store
274
275 filtered = yield defer.ensureDeferred(
276 filter_events_for_server(test_store, "test_server", events_to_filter)
277 )
278 logger.info("Filtering took %f seconds", time.time() - start)
279
280 pr.disable()
281 with open("filter_events_for_server.profile", "w+") as f:
282 ps = pstats.Stats(pr, stream=f).sort_stats("cumulative")
283 ps.print_stats()
284
285 # the result should be 5 redacted events, and 5 unredacted events.
286 for i in range(0, 5):
287 self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
288 self.assertNotIn("extra", filtered[i].content)
289
290 for i in range(5, 10):
291 self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
292 self.assertEqual(filtered[i].content["extra"], "zzz")
293
294 test_large_room.skip = "Disabled by default because it's slow"
295
296
297 class _TestStore:
298 """Implements a few methods of the DataStore, so that we can test
299 filter_events_for_server
300
301 """
302
303 def __init__(self):
304 # data for get_events: a map from event_id to event
305 self.events = {}
306
307 # data for get_state_ids_for_events mock: a map from event_id to
308 # a map from (type_state_key) -> event_id for the state at that
309 # event
310 self.state_ids_for_events = {}
311
312 def add_event(self, event):
313 self.events[event.event_id] = event
314
315 def set_state_ids_for_event(self, event, state):
316 self.state_ids_for_events[event.event_id] = state
317
318 def get_state_ids_for_events(self, events, types):
319 res = {}
320 include_memberships = False
321 for (type, state_key) in types:
322 if type == "m.room.history_visibility":
323 continue
324 if type != "m.room.member" or state_key is not None:
325 raise RuntimeError(
326 "Unimplemented: get_state_ids with type (%s, %s)"
327 % (type, state_key)
328 )
329 include_memberships = True
330
331 if include_memberships:
332 for event_id in events:
333 res[event_id] = self.state_ids_for_events[event_id]
334
335 else:
336 k = ("m.room.history_visibility", "")
337 for event_id in events:
338 hve = self.state_ids_for_events[event_id][k]
339 res[event_id] = {k: hve}
340
341 return succeed(res)
342
343 def get_events(self, events):
344 return succeed({event_id: self.events[event_id] for event_id in events})
345
346 def are_users_erased(self, users):
347 return succeed({u: False for u in users})
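`_TestStore` stubs only the store methods `filter_events_for_server` touches, returning already-fired Deferreds via `twisted.internet.defer.succeed`. The same shape in plain asyncio, as a toy analogue rather than Synapse's real interface:

```python
import asyncio

class FakeStore:
    """Toy analogue of _TestStore: dict-backed async lookups."""
    def __init__(self):
        self.events = {}

    def add_event(self, event):
        self.events[event["event_id"]] = event

    async def get_events(self, event_ids):
        # resolves immediately, like defer.succeed in the class above
        return {eid: self.events[eid] for eid in event_ids}

    async def are_users_erased(self, users):
        return {u: False for u in users}

store = FakeStore()
store.add_event({"event_id": "$e1", "type": "m.room.member"})
assert asyncio.run(store.get_events(["$e1"]))["$e1"]["type"] == "m.room.member"
```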
330330 time.sleep(0.01)
331331
332332 def wait_for_background_updates(self) -> None:
333 """
334 Block until all background database updates have completed.
335
336 Note that callers must ensure that's a store property created on the
333 """Block until all background database updates have completed.
334
335 Note that callers must ensure there's a store property created on the
337336 testcase.
338337 """
339338 while not self.get_success(
340339 self.store.db_pool.updates.has_completed_background_updates()
341340 ):
342341 self.get_success(
343 self.store.db_pool.updates.do_next_background_update(100), by=0.1
342 self.store.db_pool.updates.do_next_background_update(False), by=0.1
344343 )
345344
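This hunk reflects a signature change: `do_next_background_update` now takes a sleep flag instead of a target duration in milliseconds, and tests pass `False` to drain updates without pausing between batches. The surrounding while-loop can be sketched as follows (the two callables are hypothetical stand-ins):

```python
import asyncio

async def drain_background_updates(has_completed, do_next_update):
    """Run one batch at a time until the updater reports completion,
    mirroring wait_for_background_updates (sleep=False: no pause)."""
    while not await has_completed():
        await do_next_update(False)

# toy updater with three pending batches
remaining = [3]

async def has_completed():
    return remaining[0] == 0

async def do_next_update(sleep):
    remaining[0] -= 1

asyncio.run(drain_background_updates(has_completed, do_next_update))
assert remaining[0] == 0
```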
346345 def make_homeserver(self, reactor, clock):
499498
500499 async def run_bg_updates():
501500 with LoggingContext("run_bg_updates"):
502 while not await stor.db_pool.updates.has_completed_background_updates():
503 await stor.db_pool.updates.do_next_background_update(1)
501 self.get_success(stor.db_pool.updates.run_background_updates(False))
504502
505503 hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
506504 stor = hs.get_datastore()
1212 # limitations under the License.
1313
1414
15 from typing import List
1516 from unittest.mock import Mock
1617
1718 from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
260261 self.assertEquals(cache["key4"], [4])
261262 self.assertEquals(cache["key5"], [5, 6])
262263
264 def test_zero_size_drop_from_cache(self) -> None:
265 """Test that `drop_from_cache` works correctly with 0-sized entries."""
266 cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0)
267 cache["key1"] = []
268
269 self.assertEqual(len(cache), 0)
270 cache.cache["key1"].drop_from_cache()
271 self.assertIsNone(
272 cache.pop("key1"), "Cache entry should have been evicted but wasn't"
273 )
274
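The new test relies on `len(cache)` reporting the summed `size_callback` sizes rather than the entry count, so a zero-sized entry leaves the length at 0 while remaining retrievable. A toy sketch of that accounting (not Synapse's `LruCache`):

```python
from collections import OrderedDict

class SizedCache:
    """Toy LRU with size-based accounting; len() is the summed size."""
    def __init__(self, max_size, size_callback):
        self.max_size = max_size
        self.size_callback = size_callback
        self._data = OrderedDict()
        self._size = 0

    def __setitem__(self, key, value):
        if key in self._data:
            self._size -= self.size_callback(self._data.pop(key))
        self._data[key] = value
        self._size += self.size_callback(value)
        # evict oldest entries until back under the size budget
        while self._size > self.max_size and self._data:
            _, evicted = self._data.popitem(last=False)
            self._size -= self.size_callback(evicted)

    def pop(self, key, default=None):
        if key not in self._data:
            return default
        value = self._data.pop(key)
        self._size -= self.size_callback(value)
        return value

    def __len__(self):
        return self._size

cache = SizedCache(5, size_callback=lambda v: 0)
cache["key1"] = []
assert len(cache) == 0          # zero-sized entry: length stays 0
assert cache.pop("key1") == []  # but the entry itself is retrievable
```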
263275
264276 class TimeEvictionTestCase(unittest.HomeserverTestCase):
265277 """Test that time based eviction works correctly."""