matrix-synapse / bb2febe
New upstream version 1.22.1 Andrej Shadura 3 years ago
304 changed file(s) with 8750 addition(s) and 5082 deletion(s).
00 # This file serves as a blacklist for SyTest tests that we expect will fail in
11 # Synapse when run under worker mode. For more details, see sytest-blacklist.
22
3 Message history can be paginated
4
53 Can re-join room if re-invited
6
7 The only membership state included in an initial sync is for all the senders in the timeline
8
9 Local device key changes get to remote servers
10
11 If remote user leaves room we no longer receive device updates
12
13 Forgotten room messages cannot be paginated
14
15 Inbound federation can get public room list
16
17 Members from the gap are included in gappy incr LL sync
18
19 Leaves are present in non-gapped incremental syncs
20
21 Old leaves are present in gapped incremental syncs
22
23 User sees updates to presence from other users in the incremental sync.
24
25 Gapped incremental syncs include all state changes
26
27 Old members are included in gappy incr LL sync if they start speaking
284
295 # new failures as of https://github.com/matrix-org/sytest/pull/732
306 Device list doesn't change if remote server is down
31 Remote servers cannot set power levels in rooms without existing powerlevels
32 Remote servers should reject attempts by non-creators to set the power levels
337
348 # https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
359 Server correctly handles incoming m.device_list_update
36
37 # this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
38 Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
39
40 Can get rooms/{roomId}/members at a given point
0 version: 2
0 version: 2.1
11 jobs:
22 dockerhubuploadrelease:
3 machine: true
3 docker:
4 - image: docker:git
45 steps:
56 - checkout
6 - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} .
7 - setup_remote_docker
8 - docker_prepare
79 - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
8 - run: docker push matrixdotorg/synapse:${CIRCLE_TAG}
10 - docker_build:
11 tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
12 platforms: linux/amd64
13 - docker_build:
14 tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
15 platforms: linux/amd64,linux/arm/v7,linux/arm64
16
917 dockerhubuploadlatest:
10 machine: true
18 docker:
19 - image: docker:git
1120 steps:
1221 - checkout
13 - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest .
22 - setup_remote_docker
23 - docker_prepare
1424 - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
15 - run: docker push matrixdotorg/synapse:latest
25 - docker_build:
26 tag: -t matrixdotorg/synapse:latest
27 platforms: linux/amd64
28 - docker_build:
29 tag: -t matrixdotorg/synapse:latest
30 platforms: linux/amd64,linux/arm/v7,linux/arm64
1631
1732 workflows:
18 version: 2
1933 build:
2034 jobs:
2135 - dockerhubuploadrelease:
2842 filters:
2943 branches:
3044 only: master
45
46 commands:
47 docker_prepare:
48 description: Downloads the buildx cli plugin and enables multiarch images
49 parameters:
50 buildx_version:
51 type: string
52 default: "v0.4.1"
53 steps:
54 - run: apk add --no-cache curl
55 - run: mkdir -vp ~/.docker/cli-plugins/ ~/dockercache
56 - run: curl --silent -L "https://github.com/docker/buildx/releases/download/<< parameters.buildx_version >>/buildx-<< parameters.buildx_version >>.linux-amd64" > ~/.docker/cli-plugins/docker-buildx
57 - run: chmod a+x ~/.docker/cli-plugins/docker-buildx
58 # install qemu links in /proc/sys/fs/binfmt_misc on the docker instance running the circleci job
59 - run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
60 # create a context named `builder` for the builds
61 - run: docker context create builder
62 # create a buildx builder using the new context, and set it as the default
63 - run: docker buildx create builder --use
64
65 docker_build:
66 description: Builds and pushes images to dockerhub using buildx
67 parameters:
68 platforms:
69 type: string
70 default: linux/amd64
71 tag:
72 type: string
73 steps:
74 - run: docker buildx build -f docker/Dockerfile --push --platform << parameters.platforms >> --label gitsha1=${CIRCLE_SHA1} << parameters.tag >> --progress=plain .
2020 /.python-version
2121 /*.signing.key
2222 /env/
23 /.venv*/
2324 /homeserver*.yaml
2425 /logs
2526 /media_store/
0 Synapse 1.22.1 (2020-10-30)
1 ===========================
2
3 Bugfixes
4 --------
5
6 - Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broke in v1.22.0. ([\#8676](https://github.com/matrix-org/synapse/issues/8676))
7 - Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules. Broke in v1.22.0. ([\#8678](https://github.com/matrix-org/synapse/issues/8678))
8
9
10 Synapse 1.22.0 (2020-10-27)
11 ===========================
12
13 No significant changes.
14
15
16 Synapse 1.22.0rc2 (2020-10-26)
17 ==============================
18
19 Bugfixes
20 --------
21
22 - Fix bugs where ephemeral events were not sent to appservices. Broke in v1.22.0rc1. ([\#8648](https://github.com/matrix-org/synapse/issues/8648), [\#8656](https://github.com/matrix-org/synapse/issues/8656))
23 - Fix `user_daily_visits` table to not have duplicate rows per user/device due to multiple user agents. Broke in v1.22.0rc1. ([\#8654](https://github.com/matrix-org/synapse/issues/8654))
24
25 Synapse 1.22.0rc1 (2020-10-22)
26 ==============================
27
28 Features
29 --------
30
31 - Add a configuration option for always using the "userinfo endpoint" for OpenID Connect. This fixes support for some identity providers, e.g. GitLab. Contributed by Benjamin Koch. ([\#7658](https://github.com/matrix-org/synapse/issues/7658))
32 - Add ability for `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory. ([\#8292](https://github.com/matrix-org/synapse/issues/8292), [\#8467](https://github.com/matrix-org/synapse/issues/8467))
33 - Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)). ([\#8312](https://github.com/matrix-org/synapse/issues/8312), [\#8501](https://github.com/matrix-org/synapse/issues/8501))
34 - Add support for running background tasks in a separate worker process. ([\#8369](https://github.com/matrix-org/synapse/issues/8369), [\#8458](https://github.com/matrix-org/synapse/issues/8458), [\#8489](https://github.com/matrix-org/synapse/issues/8489), [\#8513](https://github.com/matrix-org/synapse/issues/8513), [\#8544](https://github.com/matrix-org/synapse/issues/8544), [\#8599](https://github.com/matrix-org/synapse/issues/8599))
35 - Add support for device dehydration ([MSC2697](https://github.com/matrix-org/matrix-doc/pull/2697)). ([\#8380](https://github.com/matrix-org/synapse/issues/8380))
36 - Add support for [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409), which allows sending typing, read receipts, and presence events to appservices. ([\#8437](https://github.com/matrix-org/synapse/issues/8437), [\#8590](https://github.com/matrix-org/synapse/issues/8590))
37 - Change default room version to "6", per [MSC2788](https://github.com/matrix-org/matrix-doc/pull/2788). ([\#8461](https://github.com/matrix-org/synapse/issues/8461))
38 - Add the ability to send non-membership events into a room via the `ModuleApi`. ([\#8479](https://github.com/matrix-org/synapse/issues/8479))
39 - Increase default upload size limit from 10M to 50M. Contributed by @Akkowicz. ([\#8502](https://github.com/matrix-org/synapse/issues/8502))
40 - Add support for modifying event content in `ThirdPartyRules` modules. ([\#8535](https://github.com/matrix-org/synapse/issues/8535), [\#8564](https://github.com/matrix-org/synapse/issues/8564))
41
42
43 Bugfixes
44 --------
45
46 - Fix a longstanding bug where invalid ignored users in account data could break clients. ([\#8454](https://github.com/matrix-org/synapse/issues/8454))
47 - Fix a bug where backfilling a room with an event that was missing the `redacts` field would break. ([\#8457](https://github.com/matrix-org/synapse/issues/8457))
48 - Don't attempt to respond to some requests if the client has already disconnected. ([\#8465](https://github.com/matrix-org/synapse/issues/8465))
49 - Fix message duplication if something goes wrong after persisting the event. ([\#8476](https://github.com/matrix-org/synapse/issues/8476))
50 - Fix incremental sync returning an incorrect `prev_batch` token in timeline section, which when used to paginate returned events that were included in the incremental sync. Broken since v0.16.0. ([\#8486](https://github.com/matrix-org/synapse/issues/8486))
51 - Expose the `uk.half-shot.msc2778.login.application_service` to clients from the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow. ([\#8504](https://github.com/matrix-org/synapse/issues/8504))
52 - Fix error code for `/profile/{userId}/displayname` to be `M_BAD_JSON`. ([\#8517](https://github.com/matrix-org/synapse/issues/8517))
53 - Fix a bug introduced in v1.7.0 that could cause Synapse to insert values from non-state `m.room.retention` events into the `room_retention` database table. ([\#8527](https://github.com/matrix-org/synapse/issues/8527))
54 - Fix not sending events over federation when using sharded event writers. ([\#8536](https://github.com/matrix-org/synapse/issues/8536))
55 - Fix a longstanding bug where email notifications for encrypted messages were blank. ([\#8545](https://github.com/matrix-org/synapse/issues/8545))
56 - Fix increase in the number of `There was no active span...` errors logged when using OpenTracing. ([\#8567](https://github.com/matrix-org/synapse/issues/8567))
57 - Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed. ([\#8585](https://github.com/matrix-org/synapse/issues/8585))
58 - Fix appservice transactions to only include a maximum of 100 persistent and 100 ephemeral events. ([\#8606](https://github.com/matrix-org/synapse/issues/8606))
59
60
61 Updates to the Docker image
62 ---------------------------
63
64 - Added multi-arch support (arm64,arm/v7) for the docker images. Contributed by @maquis196. ([\#7921](https://github.com/matrix-org/synapse/issues/7921))
65 - Add support for passing commandline args to the synapse process. Contributed by @samuel-p. ([\#8390](https://github.com/matrix-org/synapse/issues/8390))
66
67
68 Improved Documentation
69 ----------------------
70
71 - Update the directions for using the manhole with coroutines. ([\#8462](https://github.com/matrix-org/synapse/issues/8462))
72 - Improve readme by adding new shield.io badges. ([\#8493](https://github.com/matrix-org/synapse/issues/8493))
73 - Added note about docker in manhole.md regarding which ip address to bind to. Contributed by @Maquis196. ([\#8526](https://github.com/matrix-org/synapse/issues/8526))
74 - Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration. ([\#8529](https://github.com/matrix-org/synapse/issues/8529))
75
76
77 Deprecations and Removals
78 -------------------------
79
80 - Drop unused `device_max_stream_id` table. ([\#8589](https://github.com/matrix-org/synapse/issues/8589))
81
82
83 Internal Changes
84 ----------------
85
86 - Check for unreachable code with mypy. ([\#8432](https://github.com/matrix-org/synapse/issues/8432))
87 - Add unit test for event persister sharding. ([\#8433](https://github.com/matrix-org/synapse/issues/8433))
88 - Allow events to be sent to clients sooner when using sharded event persisters. ([\#8439](https://github.com/matrix-org/synapse/issues/8439), [\#8488](https://github.com/matrix-org/synapse/issues/8488), [\#8496](https://github.com/matrix-org/synapse/issues/8496), [\#8499](https://github.com/matrix-org/synapse/issues/8499))
89 - Configure `public_baseurl` when using demo scripts. ([\#8443](https://github.com/matrix-org/synapse/issues/8443))
90 - Add SQL logging on queries that happen during startup. ([\#8448](https://github.com/matrix-org/synapse/issues/8448))
91 - Speed up unit tests when using PostgreSQL. ([\#8450](https://github.com/matrix-org/synapse/issues/8450))
92 - Remove redundant database loads of stream_ordering for events we already have. ([\#8452](https://github.com/matrix-org/synapse/issues/8452))
93 - Reduce inconsistencies between codepaths for membership and non-membership events. ([\#8463](https://github.com/matrix-org/synapse/issues/8463))
94 - Combine `SpamCheckerApi` with the more generic `ModuleApi`. ([\#8464](https://github.com/matrix-org/synapse/issues/8464))
95 - Additional testing for `ThirdPartyEventRules`. ([\#8468](https://github.com/matrix-org/synapse/issues/8468))
96 - Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit. ([\#8472](https://github.com/matrix-org/synapse/issues/8472))
97 - Unblacklist some sytests. ([\#8474](https://github.com/matrix-org/synapse/issues/8474))
98 - Include the log level in the phone home stats. ([\#8477](https://github.com/matrix-org/synapse/issues/8477))
99 - Remove outdated sphinx documentation, scripts and configuration. ([\#8480](https://github.com/matrix-org/synapse/issues/8480))
100 - Clarify error message when plugin config parsers raise an error. ([\#8492](https://github.com/matrix-org/synapse/issues/8492))
101 - Remove the deprecated `Handlers` object. ([\#8494](https://github.com/matrix-org/synapse/issues/8494))
102 - Fix a threadsafety bug in unit tests. ([\#8497](https://github.com/matrix-org/synapse/issues/8497))
103 - Add user agent to user_daily_visits table. ([\#8503](https://github.com/matrix-org/synapse/issues/8503))
104 - Add type hints to various parts of the code base. ([\#8407](https://github.com/matrix-org/synapse/issues/8407), [\#8505](https://github.com/matrix-org/synapse/issues/8505), [\#8507](https://github.com/matrix-org/synapse/issues/8507), [\#8547](https://github.com/matrix-org/synapse/issues/8547), [\#8562](https://github.com/matrix-org/synapse/issues/8562), [\#8609](https://github.com/matrix-org/synapse/issues/8609))
105 - Remove unused code from the test framework. ([\#8514](https://github.com/matrix-org/synapse/issues/8514))
106 - Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable. ([\#8515](https://github.com/matrix-org/synapse/issues/8515))
107 - Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`. ([\#8537](https://github.com/matrix-org/synapse/issues/8537))
108 - Improve database performance by executing more queries without starting transactions. ([\#8542](https://github.com/matrix-org/synapse/issues/8542))
109 - Rename `Cache` to `DeferredCache`, to better reflect its purpose. ([\#8548](https://github.com/matrix-org/synapse/issues/8548))
110 - Move metric registration code down into `LruCache`. ([\#8561](https://github.com/matrix-org/synapse/issues/8561), [\#8591](https://github.com/matrix-org/synapse/issues/8591))
111 - Replace `DeferredCache` with the lighter-weight `LruCache` where possible. ([\#8563](https://github.com/matrix-org/synapse/issues/8563))
112 - Add virtualenv-generated folders to `.gitignore`. ([\#8566](https://github.com/matrix-org/synapse/issues/8566))
113 - Add `get_immediate` method to `DeferredCache`. ([\#8568](https://github.com/matrix-org/synapse/issues/8568))
114 - Fix mypy not properly checking across the codebase; additionally, fix a typing assertion error in `handlers/auth.py`. ([\#8569](https://github.com/matrix-org/synapse/issues/8569))
115 - Fix `synmark` benchmark runner. ([\#8571](https://github.com/matrix-org/synapse/issues/8571))
116 - Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s. ([\#8572](https://github.com/matrix-org/synapse/issues/8572))
117 - Adjust a protocol-type definition to fit `sqlite3` assertions. ([\#8577](https://github.com/matrix-org/synapse/issues/8577))
118 - Support macOS on the `synmark` benchmark runner. ([\#8578](https://github.com/matrix-org/synapse/issues/8578))
119 - Update `mypy` static type checker to 0.790. ([\#8583](https://github.com/matrix-org/synapse/issues/8583), [\#8600](https://github.com/matrix-org/synapse/issues/8600))
120 - Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting. ([\#8587](https://github.com/matrix-org/synapse/issues/8587))
121 - Remove extraneous unittest logging decorators from unit tests. ([\#8592](https://github.com/matrix-org/synapse/issues/8592))
122 - Minor optimisations in caching code. ([\#8593](https://github.com/matrix-org/synapse/issues/8593), [\#8594](https://github.com/matrix-org/synapse/issues/8594))
123
124
0125 Synapse 1.21.2 (2020-10-15)
1126 ===========================
2127
6161 ```
6262 ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
6363 ```
64
65 You can also provide the `-d` option, which will lint the files that have been
66 changed since the last git commit. This will often be significantly faster than
67 linting the whole codebase.
6468
6569 Before pushing new changes, ensure they don't produce linting errors. Commit any
6670 files that were corrected.
0 ================
1 Synapse |shield|
2 ================
3
4 .. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
5 :alt: (get support on #synapse:matrix.org)
6 :target: https://matrix.to/#/#synapse:matrix.org
0 =========================================================
1 Synapse |support| |development| |license| |pypi| |python|
2 =========================================================
73
84 .. contents::
95
289285 Client-Server API are functioning correctly. See the `installation instructions
290286 <https://github.com/matrix-org/sytest#installing>`_ for details.
291287
292 Building Internal API Documentation
293 ===================================
294
295 Before building internal API documentation install sphinx and
296 sphinxcontrib-napoleon::
297
298 pip install sphinx
299 pip install sphinxcontrib-napoleon
300
301 Building internal API documentation::
302
303 python setup.py build_sphinx
304
305288 Troubleshooting
306289 ===============
307290
386369
387370 This is normally caused by a misconfiguration in your reverse-proxy. See
388371 `<docs/reverse_proxy.md>`_ and double-check that your settings are correct.
372
373 .. |support| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
374 :alt: (get support on #synapse:matrix.org)
375 :target: https://matrix.to/#/#synapse:matrix.org
376
377 .. |development| image:: https://img.shields.io/matrix/synapse-dev:matrix.org?label=development&logo=matrix
378 :alt: (discuss development on #synapse-dev:matrix.org)
379 :target: https://matrix.to/#/#synapse-dev:matrix.org
380
381 .. |license| image:: https://img.shields.io/github/license/matrix-org/synapse
382 :alt: (check license in LICENSE file)
383 :target: LICENSE
384
385 .. |pypi| image:: https://img.shields.io/pypi/v/matrix-synapse
386 :alt: (latest version released on PyPi)
387 :target: https://pypi.org/project/matrix-synapse
388
389 .. |python| image:: https://img.shields.io/pypi/pyversions/matrix-synapse
390 :alt: (supported python versions)
391 :target: https://pypi.org/project/matrix-synapse
7373 # replace `1.3.0` and `stretch` accordingly:
7474 wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
7575 dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
76
77 Upgrading to v1.22.0
78 ====================
79
80 ThirdPartyEventRules breaking changes
81 -------------------------------------
82
83 This release introduces a backwards-incompatible change to modules making use of
84 ``ThirdPartyEventRules`` in Synapse. If you make use of a module defined under the
85 ``third_party_event_rules`` config option, please make sure it is updated to handle
86 the below change:
87
88 The ``http_client`` argument is no longer passed to modules as they are initialised. Instead,
89 modules are expected to make use of the ``http_client`` property on the ``ModuleApi`` class.
90 Modules are now passed a ``module_api`` argument during initialisation, which is an instance of
91 ``ModuleApi``. ``ModuleApi`` instances have a ``http_client`` property which acts the same as
92 the ``http_client`` argument previously passed to ``ThirdPartyEventRules`` modules.
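The upgrade note above can be summarised with a minimal sketch of an updated module. This is not taken from the Synapse documentation: the class name, the `check_event_allowed` method, the external URL and the `get_json` call are illustrative assumptions; only the `(config, module_api)` initialisation signature and the `http_client` property come from the note above.

```python
# Minimal sketch of a ThirdPartyEventRules module after the v1.22.0 change.
# Only the __init__ signature and module_api.http_client reflect the upgrade
# note; everything else is illustrative.


class ExampleRulesModule:
    def __init__(self, config: dict, module_api):
        # `module_api` is an instance of synapse.module_api.ModuleApi.
        self.config = config
        self.module_api = module_api
        # The HTTP client that used to be passed in directly is now reached
        # through the `http_client` property on ModuleApi.
        self.http_client = module_api.http_client

    async def check_event_allowed(self, event, state_events):
        # Hypothetical example: ask an external service whether to allow the event.
        resp = await self.http_client.get_json(
            "https://rules.example.com/check",  # hypothetical endpoint
            args={"event_id": event.event_id},
        )
        return resp.get("allowed", True)

    @staticmethod
    def parse_config(config: dict) -> dict:
        return config
```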
7693
7794 Upgrading to v1.21.0
7895 ====================
0 matrix-synapse-py3 (1.22.1) stable; urgency=medium
1
2 * New synapse release 1.22.1.
3
4 -- Synapse Packaging team <packages@matrix.org> Fri, 30 Oct 2020 15:25:37 +0000
5
6 matrix-synapse-py3 (1.22.0) stable; urgency=medium
7
8 * New synapse release 1.22.0.
9
10 -- Synapse Packaging team <packages@matrix.org> Tue, 27 Oct 2020 12:07:12 +0000
11
012 matrix-synapse-py3 (1.21.2) stable; urgency=medium
113
214 [ Synapse Packaging team ]
2828
2929 if ! grep -F "Customisation made by demo/start.sh" -q $DIR/etc/$port.config; then
3030 printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config
31
32 echo "public_baseurl: http://localhost:$port/" >> $DIR/etc/$port.config
3133
3234 echo 'enable_registration: true' >> $DIR/etc/$port.config
3335
8282 If all is well, you should now be able to connect to http://localhost:8008 and
8383 see a confirmation message.
8484
85 The following environment variables are supported in run mode:
85 The following environment variables are supported in `run` mode:
8686
8787 * `SYNAPSE_CONFIG_DIR`: where additional config files are stored. Defaults to
8888 `/data`.
9393 * `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`.
9494 * `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`.
9595
96 For more complex setups (e.g. workers) you can also pass arguments directly to synapse using `run` mode. For example:
97
98 ```
99 docker run -d --name synapse \
100 --mount type=volume,src=synapse-data,dst=/data \
101 -p 8008:8008 \
102 matrixdotorg/synapse:latest run \
103 -m synapse.app.generic_worker \
104 --config-path=/data/homeserver.yaml \
105 --config-path=/data/generic_worker.yaml
106 ```
107
108 If you do not provide `-m`, the value of the `SYNAPSE_WORKER` environment variable is used. If you do not provide at least one `--config-path` or `-c`, the value of the `SYNAPSE_CONFIG_PATH` environment variable is used instead.
109
96110 ## Generating an (admin) user
97111
98112 After synapse is running, you may wish to create a user via `register_new_matrix_user`.
8989
9090 media_store_path: "/data/media"
9191 uploads_path: "/data/uploads"
92 max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "10M" }}"
92 max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
9393 max_image_pixels: "32M"
9494 dynamic_thumbnails: false
9595
178178
179179
180180 def main(args, environ):
181 mode = args[1] if len(args) > 1 else None
181 mode = args[1] if len(args) > 1 else "run"
182182 desired_uid = int(environ.get("UID", "991"))
183183 desired_gid = int(environ.get("GID", "991"))
184184 synapse_worker = environ.get("SYNAPSE_WORKER", "synapse.app.homeserver")
204204 config_dir, config_path, environ, ownership
205205 )
206206
207 if mode is not None:
207 if mode != "run":
208208 error("Unknown execution mode '%s'" % (mode,))
209209
210 config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
211 config_path = environ.get("SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml")
212
213 if not os.path.exists(config_path):
214 if "SYNAPSE_SERVER_NAME" in environ:
215 error(
216 """\
210 args = args[2:]
211
212 if "-m" not in args:
213 args = ["-m", synapse_worker] + args
214
215 # if there are no config files passed to synapse, try adding the default file
216 if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
217 config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
218 config_path = environ.get(
219 "SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml"
220 )
221
222 if not os.path.exists(config_path):
223 if "SYNAPSE_SERVER_NAME" in environ:
224 error(
225 """\
217226 Config file '%s' does not exist.
218227
219228 The synapse docker image no longer supports generating a config file on-the-fly
220229 based on environment variables. You can migrate to a static config file by
221230 running with 'migrate_config'. See the README for more details.
222231 """
232 % (config_path,)
233 )
234
235 error(
236 "Config file '%s' does not exist. You should either create a new "
237 "config file by running with the `generate` argument (and then edit "
238 "the resulting file before restarting) or specify the path to an "
239 "existing config file with the SYNAPSE_CONFIG_PATH variable."
223240 % (config_path,)
224241 )
225242
226 error(
227 "Config file '%s' does not exist. You should either create a new "
228 "config file by running with the `generate` argument (and then edit "
229 "the resulting file before restarting) or specify the path to an "
230 "existing config file with the SYNAPSE_CONFIG_PATH variable."
231 % (config_path,)
232 )
233
234 log("Starting synapse with config file " + config_path)
235
236 args = ["python", "-m", synapse_worker, "--config-path", config_path]
243 args += ["--config-path", config_path]
244
245 log("Starting synapse with args " + " ".join(args))
246
247 args = ["python"] + args
237248 if ownership is not None:
238249 args = ["gosu", ownership] + args
239250 os.execv("/usr/sbin/gosu", args)
6363 - Use underscores for functions and variables.
6464 - **Docstrings**: should follow the [google code
6565 style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).
66 This is so that we can generate documentation with
67 [sphinx](http://sphinxcontrib-napoleon.readthedocs.org/en/latest/).
6866 See the
6967 [examples](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
7068 in the sphinx documentation.
44 Synapse installation. This is a very powerful mechanism for administration and
55 debugging.
66
7 **_Security Warning_**
8
9 Note that this will give administrative access to synapse to **all users** with
10 shell access to the server. It should therefore **not** be enabled in
11 environments where untrusted users have shell access.
12
13 ***
14
715 To enable it, first uncomment the `manhole` listener configuration in
8 `homeserver.yaml`:
16 `homeserver.yaml`. The configuration is slightly different if you're using docker.
17
18 #### Docker config
19
20 If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown:
21
22 ```yaml
23 listeners:
24 - port: 9000
25 bind_addresses: ['0.0.0.0']
26 type: manhole
27 ```
28
29 When using `docker run` to start the server, you will then need to change the command to the following to include the
30 `manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it
31 ensures that access to the `manhole` is only possible for local users.
32
33 ```bash
34 docker run -d --name synapse \
35 --mount type=volume,src=synapse-data,dst=/data \
36 -p 8008:8008 \
37 -p 127.0.0.1:9000:9000 \
38 matrixdotorg/synapse:latest
39 ```
40
41 #### Native config
42
43 If you are not using docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown.
44 The `bind_addresses` in the example below is important: it ensures that access to the
45 `manhole` is only possible for local users.
946
1047 ```yaml
1148 listeners:
1451 type: manhole
1552 ```
1653
17 (`bind_addresses` in the above is important: it ensures that access to the
18 manhole is only possible for local users).
19
20 Note that this will give administrative access to synapse to **all users** with
21 shell access to the server. It should therefore **not** be enabled in
22 environments where untrusted users have shell access.
54 #### Accessing synapse manhole
2355
2456 Then restart synapse, and point an ssh client at port 9000 on localhost, using
2557 the username `matrix`:
3466 `synapse.server.HomeServer` object - which in turn gives access to many other
3567 parts of the process.
3668
69 Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`.
70
3771 As a simple example, retrieving an event from the database:
3872
39 ```
40 >>> hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')
73 ```pycon
74 >>> from twisted.internet import defer
75 >>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org'))
4176 <Deferred at 0x7ff253fc6998 current result: <FrozenEvent event_id='$1416420717069yeQaw:matrix.org', type='m.room.create', state_key=''>>
4277 ```
135135
136136 ### Lifetime limits
137137
138 **Note: this feature is mainly useful within a closed federation or on
139 servers that don't federate, because there currently is no way to
140 enforce these limits in an open federation.**
141
142 Server admins can restrict the values their local users are allowed to
143 use for both `min_lifetime` and `max_lifetime`. These limits can be
144 defined as such in the `retention` section of the configuration file:
138 Server admins can set limits on the values of `max_lifetime` to use when
139 purging old events in a room. These limits can be defined as such in the
140 `retention` section of the configuration file:
145141
146142 ```yaml
147143 allowed_lifetime_min: 1d
148144 allowed_lifetime_max: 1y
149145 ```
150146
151 Here, `allowed_lifetime_min` is the lowest value a local user can set
152 for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max`
153 is the highest value. Both parameters are optional (e.g. setting
154 `allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a
155 minimum and no maximum).
147 The limits are applied when running purge jobs. If necessary, the
148 effective value of `max_lifetime` is clamped to lie between
149 `allowed_lifetime_min` and `allowed_lifetime_max` (inclusive).
150 This means that, if the value of `max_lifetime` defined in the room's state
151 is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min`
152 will be used instead. Likewise, if the value of `max_lifetime` is higher
153 than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be
154 used instead.
155
156 In the example above, we ensure Synapse never deletes events that are less
157 than one day old, and that it always deletes events that are over a year
158 old.
159
160 If a default policy is set, and its `max_lifetime` value is lower than
161 `allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same
162 process applies.
163
164 Both parameters are optional; if one is omitted Synapse won't use it to
165 adjust the effective value of `max_lifetime`.
156166
157167 Like other settings in this section, these parameters can be expressed
158168 either as a duration or as a number of milliseconds.
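The clamping behaviour described above fits in a few lines. The sketch below is illustrative only (the function name and structure are hypothetical, not Synapse's implementation); values are taken to be milliseconds.

```python
from typing import Optional

DAY_MS = 24 * 60 * 60 * 1000
YEAR_MS = 365 * DAY_MS


def effective_max_lifetime(
    policy_max_lifetime: int,
    allowed_lifetime_min: Optional[int],
    allowed_lifetime_max: Optional[int],
) -> int:
    """Clamp a room policy's max_lifetime to the server-wide limits (all in ms)."""
    lifetime = policy_max_lifetime
    if allowed_lifetime_min is not None:
        # Never delete events younger than the configured minimum.
        lifetime = max(lifetime, allowed_lifetime_min)
    if allowed_lifetime_max is not None:
        # Always delete events older than the configured maximum.
        lifetime = min(lifetime, allowed_lifetime_max)
    return lifetime


# With the YAML example above (1d minimum, 1y maximum): a policy asking for one
# hour is raised to one day, and one asking for five years is lowered to one year.
assert effective_max_lifetime(60 * 60 * 1000, DAY_MS, YEAR_MS) == DAY_MS
assert effective_max_lifetime(5 * YEAR_MS, DAY_MS, YEAR_MS) == YEAR_MS
```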
237237
238238 ```yaml
239239 oidc_config:
240 enabled: true
241 issuer: "https://id.twitch.tv/oauth2/"
242 client_id: "your-client-id" # TO BE FILLED
243 client_secret: "your-client-secret" # TO BE FILLED
244 client_auth_method: "client_secret_post"
245 user_mapping_provider:
246 config:
247 localpart_template: '{{ user.preferred_username }}'
248 display_name_template: '{{ user.name }}'
249 ```
240 enabled: true
241 issuer: "https://id.twitch.tv/oauth2/"
242 client_id: "your-client-id" # TO BE FILLED
243 client_secret: "your-client-secret" # TO BE FILLED
244 client_auth_method: "client_secret_post"
245 user_mapping_provider:
246 config:
247 localpart_template: "{{ user.preferred_username }}"
248 display_name_template: "{{ user.name }}"
249 ```
250
251 ### GitLab
252
253 1. Create a [new application](https://gitlab.com/profile/applications).
254 2. Add the `read_user` and `openid` scopes.
255 3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
256
257 Synapse config:
258
259 ```yaml
260 oidc_config:
261 enabled: true
262 issuer: "https://gitlab.com/"
263 client_id: "your-client-id" # TO BE FILLED
264 client_secret: "your-client-secret" # TO BE FILLED
265 client_auth_method: "client_secret_post"
266 scopes: ["openid", "read_user"]
267 user_profile_method: "userinfo_endpoint"
268 user_mapping_provider:
269 config:
270 localpart_template: '{{ user.nickname }}'
271 display_name_template: '{{ user.name }}'
272 ```
5353 proxy_set_header X-Forwarded-For $remote_addr;
5454 # Nginx by default only allows file uploads up to 1M in size
5555 # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
56 client_max_body_size 10M;
56 client_max_body_size 50M;
5757 }
5858 }
5959 ```
118118 # For example, for room version 1, default_room_version should be set
119119 # to "1".
120120 #
121 #default_room_version: "5"
121 #default_room_version: "6"
122122
123123 # The GC threshold parameters to pass to `gc.set_threshold`, if defined
124124 #
892892
893893 # The largest allowed upload size in bytes
894894 #
895 #max_upload_size: 10M
895 #max_upload_size: 50M
896896
897897 # Maximum number of pixels that will be thumbnailed
898898 #
17131713 #
17141714 #skip_verification: true
17151715
1716 # Whether to fetch the user profile from the userinfo endpoint. Valid
1717 # values are: "auto" or "userinfo_endpoint".
1718 #
1719 # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
1720 # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
1721 #
1722 #user_profile_method: "userinfo_endpoint"
1723
17161724 # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
17171725 # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
17181726 #
24952503 # events: worker1
24962504 # typing: worker1
24972505
2506 # The worker that is used to run background tasks (e.g. cleaning up expired
2507 # data). If not provided this defaults to the main process.
2508 #
2509 #run_background_tasks_on: worker1
2510
24982511
24992512 # Configuration for Redis when using workers. This *must* be enabled when
25002513 # using workers (unless using old style direct TCP configuration).
1010 The Python class is instantiated with two objects:
1111
1212 * Any configuration (see below).
13 * An instance of `synapse.spam_checker_api.SpamCheckerApi`.
13 * An instance of `synapse.module_api.ModuleApi`.
1414
1515 It then implements methods which return a boolean to alter behavior in Synapse.
1616
2525 The details of each of these methods (as well as their inputs and outputs)
2626 are documented in the `synapse.events.spamcheck.SpamChecker` class.
2727
28 The `SpamCheckerApi` class provides a way for the custom spam checker class to
29 call back into the homeserver internals. It currently implements the following
30 methods:
31
32 * `get_state_events_in_room`
28 The `ModuleApi` class provides a way for the custom spam checker class to
29 call back into the homeserver internals.
3330
3431 ### Example
3532
+0 -1 docs/sphinx/README.rst
0 TODO: how (if at all) is this actually maintained?
+0 -271 docs/sphinx/conf.py
0 # -*- coding: utf-8 -*-
1 #
2 # Synapse documentation build configuration file, created by
3 # sphinx-quickstart on Tue Jun 10 17:31:02 2014.
4 #
5 # This file is execfile()d with the current directory set to its
6 # containing dir.
7 #
8 # Note that not all possible configuration values are present in this
9 # autogenerated file.
10 #
11 # All configuration values have a default; values that are commented out
12 # serve to show the default.
13
14 import sys
15 import os
16
17 # If extensions (or modules to document with autodoc) are in another directory,
18 # add these directories to sys.path here. If the directory is relative to the
19 # documentation root, use os.path.abspath to make it absolute, like shown here.
20 sys.path.insert(0, os.path.abspath(".."))
21
22 # -- General configuration ------------------------------------------------
23
24 # If your documentation needs a minimal Sphinx version, state it here.
25 # needs_sphinx = '1.0'
26
27 # Add any Sphinx extension module names here, as strings. They can be
28 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 # ones.
30 extensions = [
31 "sphinx.ext.autodoc",
32 "sphinx.ext.intersphinx",
33 "sphinx.ext.coverage",
34 "sphinx.ext.ifconfig",
35 "sphinxcontrib.napoleon",
36 ]
37
38 # Add any paths that contain templates here, relative to this directory.
39 templates_path = ["_templates"]
40
41 # The suffix of source filenames.
42 source_suffix = ".rst"
43
44 # The encoding of source files.
45 # source_encoding = 'utf-8-sig'
46
47 # The master toctree document.
48 master_doc = "index"
49
50 # General information about the project.
51 project = "Synapse"
52 copyright = (
53 "Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd"
54 )
55
56 # The version info for the project you're documenting, acts as replacement for
57 # |version| and |release|, also used in various other places throughout the
58 # built documents.
59 #
60 # The short X.Y version.
61 version = "1.0"
62 # The full version, including alpha/beta/rc tags.
63 release = "1.0"
64
65 # The language for content autogenerated by Sphinx. Refer to documentation
66 # for a list of supported languages.
67 # language = None
68
69 # There are two options for replacing |today|: either, you set today to some
70 # non-false value, then it is used:
71 # today = ''
72 # Else, today_fmt is used as the format for a strftime call.
73 # today_fmt = '%B %d, %Y'
74
75 # List of patterns, relative to source directory, that match files and
76 # directories to ignore when looking for source files.
77 exclude_patterns = ["_build"]
78
79 # The reST default role (used for this markup: `text`) to use for all
80 # documents.
81 # default_role = None
82
83 # If true, '()' will be appended to :func: etc. cross-reference text.
84 # add_function_parentheses = True
85
86 # If true, the current module name will be prepended to all description
87 # unit titles (such as .. function::).
88 # add_module_names = True
89
90 # If true, sectionauthor and moduleauthor directives will be shown in the
91 # output. They are ignored by default.
92 # show_authors = False
93
94 # The name of the Pygments (syntax highlighting) style to use.
95 pygments_style = "sphinx"
96
97 # A list of ignored prefixes for module index sorting.
98 # modindex_common_prefix = []
99
100 # If true, keep warnings as "system message" paragraphs in the built documents.
101 # keep_warnings = False
102
103
104 # -- Options for HTML output ----------------------------------------------
105
106 # The theme to use for HTML and HTML Help pages. See the documentation for
107 # a list of builtin themes.
108 html_theme = "default"
109
110 # Theme options are theme-specific and customize the look and feel of a theme
111 # further. For a list of options available for each theme, see the
112 # documentation.
113 # html_theme_options = {}
114
115 # Add any paths that contain custom themes here, relative to this directory.
116 # html_theme_path = []
117
118 # The name for this set of Sphinx documents. If None, it defaults to
119 # "<project> v<release> documentation".
120 # html_title = None
121
122 # A shorter title for the navigation bar. Default is the same as html_title.
123 # html_short_title = None
124
125 # The name of an image file (relative to this directory) to place at the top
126 # of the sidebar.
127 # html_logo = None
128
129 # The name of an image file (within the static path) to use as favicon of the
130 # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
131 # pixels large.
132 # html_favicon = None
133
134 # Add any paths that contain custom static files (such as style sheets) here,
135 # relative to this directory. They are copied after the builtin static files,
136 # so a file named "default.css" will overwrite the builtin "default.css".
137 html_static_path = ["_static"]
138
139 # Add any extra paths that contain custom files (such as robots.txt or
140 # .htaccess) here, relative to this directory. These files are copied
141 # directly to the root of the documentation.
142 # html_extra_path = []
143
144 # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
145 # using the given strftime format.
146 # html_last_updated_fmt = '%b %d, %Y'
147
148 # If true, SmartyPants will be used to convert quotes and dashes to
149 # typographically correct entities.
150 # html_use_smartypants = True
151
152 # Custom sidebar templates, maps document names to template names.
153 # html_sidebars = {}
154
155 # Additional templates that should be rendered to pages, maps page names to
156 # template names.
157 # html_additional_pages = {}
158
159 # If false, no module index is generated.
160 # html_domain_indices = True
161
162 # If false, no index is generated.
163 # html_use_index = True
164
165 # If true, the index is split into individual pages for each letter.
166 # html_split_index = False
167
168 # If true, links to the reST sources are added to the pages.
169 # html_show_sourcelink = True
170
171 # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
172 # html_show_sphinx = True
173
174 # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
175 # html_show_copyright = True
176
177 # If true, an OpenSearch description file will be output, and all pages will
178 # contain a <link> tag referring to it. The value of this option must be the
179 # base URL from which the finished HTML is served.
180 # html_use_opensearch = ''
181
182 # This is the file name suffix for HTML files (e.g. ".xhtml").
183 # html_file_suffix = None
184
185 # Output file base name for HTML help builder.
186 htmlhelp_basename = "Synapsedoc"
187
188
189 # -- Options for LaTeX output ---------------------------------------------
190
191 latex_elements = {
192 # The paper size ('letterpaper' or 'a4paper').
193 #'papersize': 'letterpaper',
194 # The font size ('10pt', '11pt' or '12pt').
195 #'pointsize': '10pt',
196 # Additional stuff for the LaTeX preamble.
197 #'preamble': '',
198 }
199
200 # Grouping the document tree into LaTeX files. List of tuples
201 # (source start file, target name, title,
202 # author, documentclass [howto, manual, or own class]).
203 latex_documents = [("index", "Synapse.tex", "Synapse Documentation", "TNG", "manual")]
204
205 # The name of an image file (relative to this directory) to place at the top of
206 # the title page.
207 # latex_logo = None
208
209 # For "manual" documents, if this is true, then toplevel headings are parts,
210 # not chapters.
211 # latex_use_parts = False
212
213 # If true, show page references after internal links.
214 # latex_show_pagerefs = False
215
216 # If true, show URL addresses after external links.
217 # latex_show_urls = False
218
219 # Documents to append as an appendix to all manuals.
220 # latex_appendices = []
221
222 # If false, no module index is generated.
223 # latex_domain_indices = True
224
225
226 # -- Options for manual page output ---------------------------------------
227
228 # One entry per manual page. List of tuples
229 # (source start file, name, description, authors, manual section).
230 man_pages = [("index", "synapse", "Synapse Documentation", ["TNG"], 1)]
231
232 # If true, show URL addresses after external links.
233 # man_show_urls = False
234
235
236 # -- Options for Texinfo output -------------------------------------------
237
238 # Grouping the document tree into Texinfo files. List of tuples
239 # (source start file, target name, title, author,
240 # dir menu entry, description, category)
241 texinfo_documents = [
242 (
243 "index",
244 "Synapse",
245 "Synapse Documentation",
246 "TNG",
247 "Synapse",
248 "One line description of project.",
249 "Miscellaneous",
250 )
251 ]
252
253 # Documents to append as an appendix to all manuals.
254 # texinfo_appendices = []
255
256 # If false, no module index is generated.
257 # texinfo_domain_indices = True
258
259 # How to display URL addresses: 'footnote', 'no', or 'inline'.
260 # texinfo_show_urls = 'footnote'
261
262 # If true, do not generate a @detailmenu in the "Top" node's menu.
263 # texinfo_no_detailmenu = False
264
265
266 # Example configuration for intersphinx: refer to the Python standard library.
267 intersphinx_mapping = {"http://docs.python.org/": None}
268
269 napoleon_include_special_with_doc = True
270 napoleon_use_ivar = True
+0 -20 docs/sphinx/index.rst
0 .. Synapse documentation master file, created by
1 sphinx-quickstart on Tue Jun 10 17:31:02 2014.
2 You can adapt this file completely to your liking, but it should at least
3 contain the root `toctree` directive.
4
5 Welcome to Synapse's documentation!
6 ===================================
7
8 Contents:
9
10 .. toctree::
11 synapse
12
13 Indices and tables
14 ==================
15
16 * :ref:`genindex`
17 * :ref:`modindex`
18 * :ref:`search`
19
+0 -7 docs/sphinx/modules.rst
0 synapse
1 =======
2
3 .. toctree::
4 :maxdepth: 4
5
6 synapse
+0 -7 docs/sphinx/synapse.api.auth.rst
0 synapse.api.auth module
1 =======================
2
3 .. automodule:: synapse.api.auth
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.constants.rst
0 synapse.api.constants module
1 ============================
2
3 .. automodule:: synapse.api.constants
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.dbobjects.rst
0 synapse.api.dbobjects module
1 ============================
2
3 .. automodule:: synapse.api.dbobjects
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.errors.rst
0 synapse.api.errors module
1 =========================
2
3 .. automodule:: synapse.api.errors
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.event_stream.rst
0 synapse.api.event_stream module
1 ===============================
2
3 .. automodule:: synapse.api.event_stream
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.events.factory.rst
0 synapse.api.events.factory module
1 =================================
2
3 .. automodule:: synapse.api.events.factory
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.events.room.rst
0 synapse.api.events.room module
1 ==============================
2
3 .. automodule:: synapse.api.events.room
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -18 docs/sphinx/synapse.api.events.rst
0 synapse.api.events package
1 ==========================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.api.events.factory
9 synapse.api.events.room
10
11 Module contents
12 ---------------
13
14 .. automodule:: synapse.api.events
15 :members:
16 :undoc-members:
17 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.handlers.events.rst
0 synapse.api.handlers.events module
1 ==================================
2
3 .. automodule:: synapse.api.handlers.events
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.handlers.factory.rst
0 synapse.api.handlers.factory module
1 ===================================
2
3 .. automodule:: synapse.api.handlers.factory
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.handlers.federation.rst
0 synapse.api.handlers.federation module
1 ======================================
2
3 .. automodule:: synapse.api.handlers.federation
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.handlers.register.rst
0 synapse.api.handlers.register module
1 ====================================
2
3 .. automodule:: synapse.api.handlers.register
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.handlers.room.rst
0 synapse.api.handlers.room module
1 ================================
2
3 .. automodule:: synapse.api.handlers.room
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -21 docs/sphinx/synapse.api.handlers.rst
0 synapse.api.handlers package
1 ============================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.api.handlers.events
9 synapse.api.handlers.factory
10 synapse.api.handlers.federation
11 synapse.api.handlers.register
12 synapse.api.handlers.room
13
14 Module contents
15 ---------------
16
17 .. automodule:: synapse.api.handlers
18 :members:
19 :undoc-members:
20 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.notifier.rst
0 synapse.api.notifier module
1 ===========================
2
3 .. automodule:: synapse.api.notifier
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.register_events.rst
0 synapse.api.register_events module
1 ==================================
2
3 .. automodule:: synapse.api.register_events
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.room_events.rst
0 synapse.api.room_events module
1 ==============================
2
3 .. automodule:: synapse.api.room_events
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -30 docs/sphinx/synapse.api.rst
0 synapse.api package
1 ===================
2
3 Subpackages
4 -----------
5
6 .. toctree::
7
8 synapse.api.events
9 synapse.api.handlers
10 synapse.api.streams
11
12 Submodules
13 ----------
14
15 .. toctree::
16
17 synapse.api.auth
18 synapse.api.constants
19 synapse.api.errors
20 synapse.api.notifier
21 synapse.api.storage
22
23 Module contents
24 ---------------
25
26 .. automodule:: synapse.api
27 :members:
28 :undoc-members:
29 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.server.rst
0 synapse.api.server module
1 =========================
2
3 .. automodule:: synapse.api.server
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.storage.rst
0 synapse.api.storage module
1 ==========================
2
3 .. automodule:: synapse.api.storage
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.stream.rst
0 synapse.api.stream module
1 =========================
2
3 .. automodule:: synapse.api.stream
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.api.streams.event.rst
0 synapse.api.streams.event module
1 ================================
2
3 .. automodule:: synapse.api.streams.event
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -17 docs/sphinx/synapse.api.streams.rst
0 synapse.api.streams package
1 ===========================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.api.streams.event
9
10 Module contents
11 ---------------
12
13 .. automodule:: synapse.api.streams
14 :members:
15 :undoc-members:
16 :show-inheritance:
+0 -7 docs/sphinx/synapse.app.homeserver.rst
0 synapse.app.homeserver module
1 =============================
2
3 .. automodule:: synapse.app.homeserver
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -17 docs/sphinx/synapse.app.rst
0 synapse.app package
1 ===================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.app.homeserver
9
10 Module contents
11 ---------------
12
13 .. automodule:: synapse.app
14 :members:
15 :undoc-members:
16 :show-inheritance:
+0 -10 docs/sphinx/synapse.db.rst
0 synapse.db package
1 ==================
2
3 Module contents
4 ---------------
5
6 .. automodule:: synapse.db
7 :members:
8 :undoc-members:
9 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.handler.rst
0 synapse.federation.handler module
1 =================================
2
3 .. automodule:: synapse.federation.handler
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.messaging.rst
0 synapse.federation.messaging module
1 ===================================
2
3 .. automodule:: synapse.federation.messaging
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.pdu_codec.rst
0 synapse.federation.pdu_codec module
1 ===================================
2
3 .. automodule:: synapse.federation.pdu_codec
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.persistence.rst
0 synapse.federation.persistence module
1 =====================================
2
3 .. automodule:: synapse.federation.persistence
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.replication.rst
0 synapse.federation.replication module
1 =====================================
2
3 .. automodule:: synapse.federation.replication
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -22 docs/sphinx/synapse.federation.rst
0 synapse.federation package
1 ==========================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.federation.handler
9 synapse.federation.pdu_codec
10 synapse.federation.persistence
11 synapse.federation.replication
12 synapse.federation.transport
13 synapse.federation.units
14
15 Module contents
16 ---------------
17
18 .. automodule:: synapse.federation
19 :members:
20 :undoc-members:
21 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.transport.rst
0 synapse.federation.transport module
1 ===================================
2
3 .. automodule:: synapse.federation.transport
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.federation.units.rst
0 synapse.federation.units module
1 ===============================
2
3 .. automodule:: synapse.federation.units
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -19 docs/sphinx/synapse.persistence.rst
0 synapse.persistence package
1 ===========================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.persistence.service
9 synapse.persistence.tables
10 synapse.persistence.transactions
11
12 Module contents
13 ---------------
14
15 .. automodule:: synapse.persistence
16 :members:
17 :undoc-members:
18 :show-inheritance:
+0 -7 docs/sphinx/synapse.persistence.service.rst
0 synapse.persistence.service module
1 ==================================
2
3 .. automodule:: synapse.persistence.service
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.persistence.tables.rst
0 synapse.persistence.tables module
1 =================================
2
3 .. automodule:: synapse.persistence.tables
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.persistence.transactions.rst
0 synapse.persistence.transactions module
1 =======================================
2
3 .. automodule:: synapse.persistence.transactions
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.rest.base.rst
0 synapse.rest.base module
1 ========================
2
3 .. automodule:: synapse.rest.base
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.rest.events.rst
0 synapse.rest.events module
1 ==========================
2
3 .. automodule:: synapse.rest.events
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.rest.register.rst
0 synapse.rest.register module
1 ============================
2
3 .. automodule:: synapse.rest.register
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -7 docs/sphinx/synapse.rest.room.rst
0 synapse.rest.room module
1 ========================
2
3 .. automodule:: synapse.rest.room
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0 -20 docs/sphinx/synapse.rest.rst
0 synapse.rest package
1 ====================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.rest.base
9 synapse.rest.events
10 synapse.rest.register
11 synapse.rest.room
12
13 Module contents
14 ---------------
15
16 .. automodule:: synapse.rest
17 :members:
18 :undoc-members:
19 :show-inheritance:
+0 -30 docs/sphinx/synapse.rst
0 synapse package
1 ===============
2
3 Subpackages
4 -----------
5
6 .. toctree::
7
8 synapse.api
9 synapse.app
10 synapse.federation
11 synapse.persistence
12 synapse.rest
13 synapse.util
14
15 Submodules
16 ----------
17
18 .. toctree::
19
20 synapse.server
21 synapse.state
22
23 Module contents
24 ---------------
25
26 .. automodule:: synapse
27 :members:
28 :undoc-members:
29 :show-inheritance:
+0 -7 docs/sphinx/synapse.server.rst
0 synapse.server module
1 =====================
2
3 .. automodule:: synapse.server
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-7
docs/sphinx/synapse.state.rst
0 synapse.state module
1 ====================
2
3 .. automodule:: synapse.state
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-7
docs/sphinx/synapse.util.async.rst
0 synapse.util.async module
1 =========================
2
3 .. automodule:: synapse.util.async
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-7
docs/sphinx/synapse.util.dbutils.rst
0 synapse.util.dbutils module
1 ===========================
2
3 .. automodule:: synapse.util.dbutils
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-7
docs/sphinx/synapse.util.http.rst
0 synapse.util.http module
1 ========================
2
3 .. automodule:: synapse.util.http
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-7
docs/sphinx/synapse.util.lockutils.rst
0 synapse.util.lockutils module
1 =============================
2
3 .. automodule:: synapse.util.lockutils
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-7
docs/sphinx/synapse.util.logutils.rst
0 synapse.util.logutils module
1 ============================
2
3 .. automodule:: synapse.util.logutils
4 :members:
5 :undoc-members:
6 :show-inheritance:
+0
-21
docs/sphinx/synapse.util.rst
0 synapse.util package
1 ====================
2
3 Submodules
4 ----------
5
6 .. toctree::
7
8 synapse.util.async
9 synapse.util.http
10 synapse.util.lockutils
11 synapse.util.logutils
12 synapse.util.stringutils
13
14 Module contents
15 ---------------
16
17 .. automodule:: synapse.util
18 :members:
19 :undoc-members:
20 :show-inheritance:
+0
-7
docs/sphinx/synapse.util.stringutils.rst
0 synapse.util.stringutils module
1 ===============================
2
3 .. automodule:: synapse.util.stringutils
4 :members:
5 :undoc-members:
6 :show-inheritance:
1414
1515 > SERVER example.com
1616 < REPLICATE
17 > POSITION events master 53
17 > POSITION events master 53 53
1818 > RDATA events master 54 ["$foo1:bar.com", ...]
1919 > RDATA events master 55 ["$foo4:bar.com", ...]
2020
137137 < NAME synapse.app.appservice
138138 < PING 1490197665618
139139 < REPLICATE
140 > POSITION events master 1
141 > POSITION backfill master 1
142 > POSITION caches master 1
140 > POSITION events master 1 1
141 > POSITION backfill master 1 1
142 > POSITION caches master 1 1
143143 > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
144144 > RDATA events master 14 ["$149019767112vOHxz:localhost:8823",
145145 "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
184184 updates via HTTP API, rather than via the DB, then processes should make the
185185 request to the appropriate process.
186186
187 Two positions are included: the "new" position and the last position sent, respectively.
188 This allows servers to tell instances that the positions have advanced but no
189 data has been written, without clients needlessly checking to see if they
190 have missed any updates.
191
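As a rough, editorial sketch of why the second token matters (this is not Synapse's actual client code; the parameter names and token ordering are assumptions based on the paragraph above): a replication client that is already up to date with everything the server has sent can simply fast-forward, and only needs a catch-up request when there is a real gap.

```python
from typing import Callable

def on_position(
    current_token: int,
    sent_upto: int,
    new_token: int,
    fetch_missing: Callable[[int, int], None],
) -> int:
    """Handle a POSITION for one stream; return the token to advance to."""
    if sent_upto <= current_token:
        # Nothing beyond our current token was ever sent, so the stream
        # advanced without writing data we are missing: just fast-forward.
        return new_token
    # There is a gap between what we have and what was sent: catch up
    # (e.g. via the HTTP API mentioned above) before advancing.
    fetch_missing(current_token, new_token)
    return new_token

# At token 53, receiving "POSITION events master 53 53" requires no fetch.
assert on_position(53, 53, 53, lambda lo, hi: None) == 53
```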
187192 #### ERROR (S, C)
188193
189194 There was an error
318318 events: event_persister1
319319 ```
320320
321 #### Background tasks
322
323 There is also *experimental* support for moving background tasks to a separate
324 worker. Background tasks are run periodically or started via replication. Exactly
325 which tasks are configured to run depends on your Synapse configuration (e.g. if
326 stats is enabled).
327
328 To enable this, the worker must have a `worker_name` and can be configured to run
329 background tasks. For example, to move background tasks to a dedicated worker,
330 the shared configuration would include:
331
332 ```yaml
333 run_background_tasks_on: background_worker
334 ```
335
336 You might also wish to investigate the `update_user_directory` and
337 `media_instance_running_background_jobs` settings.
321338
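For orientation, the check that decides which process actually runs these tasks is added to `synapse/config/workers.py` further down in this diff; the snippet below restates that logic as a standalone sketch (the function name and shape are illustrative, only the decision itself comes from the diff).

```python
from typing import Optional

def runs_background_tasks(worker_name: Optional[str], shared_config: dict) -> bool:
    """Mirror of the run_background_tasks check: the main process (worker_name
    is None) runs the tasks unless run_background_tasks_on names a worker."""
    instance = shared_config.get("run_background_tasks_on") or "master"
    return (worker_name is None and instance == "master") or worker_name == instance

shared_config = {"run_background_tasks_on": "background_worker"}
assert not runs_background_tasks(None, shared_config)             # main process
assert runs_background_tasks("background_worker", shared_config)  # dedicated worker
```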
322339 ### `synapse.app.pusher`
323340
55 show_error_codes = True
66 show_traceback = True
77 mypy_path = stubs
8 warn_unreachable = True
89 files =
910 synapse/api,
1011 synapse/appservice,
1314 synapse/events/builder.py,
1415 synapse/events/spamcheck.py,
1516 synapse/federation,
17 synapse/handlers/_base.py,
18 synapse/handlers/account_data.py,
19 synapse/handlers/appservice.py,
1620 synapse/handlers/auth.py,
1721 synapse/handlers/cas_handler.py,
22 synapse/handlers/deactivate_account.py,
23 synapse/handlers/device.py,
24 synapse/handlers/devicemessage.py,
1825 synapse/handlers/directory.py,
1926 synapse/handlers/events.py,
2027 synapse/handlers/federation.py,
2330 synapse/handlers/message.py,
2431 synapse/handlers/oidc_handler.py,
2532 synapse/handlers/pagination.py,
33 synapse/handlers/password_policy.py,
2634 synapse/handlers/presence.py,
35 synapse/handlers/profile.py,
36 synapse/handlers/read_marker.py,
2737 synapse/handlers/room.py,
2838 synapse/handlers/room_member.py,
2939 synapse/handlers/room_member_worker.py,
5666 synapse/streams,
5767 synapse/types.py,
5868 synapse/util/async_helpers.py,
59 synapse/util/caches/descriptors.py,
60 synapse/util/caches/stream_change_cache.py,
69 synapse/util/caches,
6170 synapse/util/metrics.py,
6271 tests/replication,
6372 tests/test_utils,
141150
142151 [mypy-nacl.*]
143152 ignore_missing_imports = True
153
154 [mypy-hiredis]
155 ignore_missing_imports = True
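The newly added `warn_unreachable = True` makes mypy report statements it can prove will never execute. A minimal, self-contained illustration (not code from Synapse itself):

```python
def describe(x: int) -> str:
    if isinstance(x, int):
        return "int"
    # With warn_unreachable enabled, mypy flags this line as unreachable:
    # for a parameter annotated int, the isinstance() check always succeeds.
    return "not an int"
```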
2121 import sys
2222 import time
2323 import traceback
24 from typing import Optional
2425
2526 import yaml
2627
8990 "room_stats_state": ["is_federatable"],
9091 "local_media_repository": ["safe_from_quarantine"],
9192 "users": ["shadow_banned"],
93 "e2e_fallback_keys_json": ["used"],
9294 }
9395
9496
150152
151153 # Error returned by the run function. Used at the top-level part of the script to
152154 # handle errors and return codes.
153 end_error = None
155 end_error = None # type: Optional[str]
154156 # The exec_info for the error, if any. If error is defined but not exec_info the script
155157 # will show only the error message without the stacktrace, if exec_info is defined but
156158 # not the error then the script will show nothing outside of what's printed in the run
488490
489491 hs = MockHomeserver(self.hs_config)
490492
491 with make_conn(db_config, engine) as db_conn:
493 with make_conn(db_config, engine, "portdb") as db_conn:
492494 engine.check_database(
493495 db_conn, allow_outdated_version=allow_outdated_version
494496 )
633635 self.progress.done()
634636 except Exception as e:
635637 global end_error_exec_info
636 end_error = e
638 end_error = str(e)
637639 end_error_exec_info = sys.exc_info()
638640 logger.exception("")
639641 finally:
0 #!/bin/sh
0 #!/bin/bash
11 #
22 # Runs linting scripts over the local Synapse checkout
33 # isort - sorts import statements
66
77 set -e
88
9 if [ $# -ge 1 ]
10 then
11 files=$*
12 else
13 files="synapse tests scripts-dev scripts contrib synctl"
9 usage() {
10 echo
11 echo "Usage: $0 [-h] [-d] [paths...]"
12 echo
13 echo "-d"
14 echo " Lint files that have changed since the last git commit."
15 echo
16 echo " If paths are provided and this option is set, both provided paths and those"
17 echo " that have changed since the last commit will be linted."
18 echo
19 echo " If no paths are provided and this option is not set, all files will be linted."
20 echo
21 echo " Note that paths with a file extension that is not '.py' will be excluded."
22 echo "-h"
23 echo " Display this help text."
24 }
25
26 USING_DIFF=0
27 files=()
28
29 while getopts ":dh" opt; do
30 case $opt in
31 d)
32 USING_DIFF=1
33 ;;
34 h)
35 usage
36 exit
37 ;;
38 \?)
39 echo "ERROR: Invalid option: -$OPTARG" >&2
40 usage
41 exit
42 ;;
43 esac
44 done
45
46 # Strip any options from the command line arguments now that
47 # we've finished processing them
48 shift "$((OPTIND-1))"
49
50 if [ $USING_DIFF -eq 1 ]; then
51 # Check both staged and non-staged changes
52 for path in $(git diff HEAD --name-only); do
53 filename=$(basename "$path")
54 file_extension="${filename##*.}"
55
56 # If an extension is present, and it's something other than 'py',
57 # then ignore this file
58 if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then
59 continue
60 fi
61
62 # Append this path to our list of files to lint
63 files+=("$path")
64 done
1465 fi
1566
16 echo "Linting these locations: $files"
17 isort $files
18 python3 -m black $files
67 # Append any remaining arguments as files to lint
68 files+=("$@")
69
70 if [[ $USING_DIFF -eq 1 ]]; then
71 # If we were asked to lint changed files, and no paths were found as a result...
72 if [ ${#files[@]} -eq 0 ]; then
73 # Then print and exit
74 echo "No files found to lint."
75 exit 0
76 fi
77 else
78 # If we were not asked to lint changed files, and no paths were found as a result,
79 # then lint everything!
80 if [[ -z ${files+x} ]]; then
81 # Lint all source code files and directories
82 files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
83 fi
84 fi
85
86 echo "Linting these paths: ${files[*]}"
87 echo
88
89 # Print out the commands being run
90 set -x
91
92 isort "${files[@]}"
93 python3 -m black "${files[@]}"
1994 ./scripts-dev/config-lint.sh
20 flake8 $files
95 flake8 "${files[@]}"
+0
-1
scripts-dev/sphinx_api_docs.sh
0 sphinx-apidoc -o docs/sphinx/ synapse/ -ef
0 [build_sphinx]
1 source-dir = docs/sphinx
2 build-dir = docs/build
3 all_files = 1
4
50 [trial]
61 test_suite = tests
72
1414 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515 # See the License for the specific language governing permissions and
1616 # limitations under the License.
17
1817 import glob
1918 import os
20 from setuptools import setup, find_packages, Command
21 import sys
2219
20 from setuptools import Command, find_packages, setup
2321
2422 here = os.path.abspath(os.path.dirname(__file__))
2523
103101 "flake8",
104102 ]
105103
104 CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
105
106106 # Dependencies which are exclusively required by unit test code. This is
107107 # NOT a list of all modules that are necessary to run the unit tests.
108108 # Tests assume that all optional dependencies are installed.
0 from .sorteddict import (
1 SortedDict,
2 SortedKeysView,
3 SortedItemsView,
4 SortedValuesView,
5 )
0 from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView
1 from .sortedlist import SortedKeyList, SortedList, SortedListWithKey
62
73 __all__ = [
84 "SortedDict",
95 "SortedKeysView",
106 "SortedItemsView",
117 "SortedValuesView",
8 "SortedKeyList",
9 "SortedList",
10 "SortedListWithKey",
1211 ]
0 # stub for SortedList. This is an exact copy of
1 # https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi
2 # (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
3
4 from typing import (
5 Any,
6 Callable,
7 Generic,
8 Iterable,
9 Iterator,
10 List,
11 MutableSequence,
12 Optional,
13 Sequence,
14 Tuple,
15 Type,
16 TypeVar,
17 Union,
18 overload,
19 )
20
21 _T = TypeVar("_T")
22 _SL = TypeVar("_SL", bound=SortedList)
23 _SKL = TypeVar("_SKL", bound=SortedKeyList)
24 _Key = Callable[[_T], Any]
25 _Repr = Callable[[], str]
26
27 def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ...
28
29 class SortedList(MutableSequence[_T]):
30
31 DEFAULT_LOAD_FACTOR: int = ...
32 def __init__(
33 self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
34 ): ...
35 # NB: currently mypy does not honour return type, see mypy #3307
36 @overload
37 def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ...
38 @overload
39 def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ...
40 @overload
41 def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ...
42 @overload
43 def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
44 @property
45 def key(self) -> Optional[Callable[[_T], Any]]: ...
46 def _reset(self, load: int) -> None: ...
47 def clear(self) -> None: ...
48 def _clear(self) -> None: ...
49 def add(self, value: _T) -> None: ...
50 def _expand(self, pos: int) -> None: ...
51 def update(self, iterable: Iterable[_T]) -> None: ...
52 def _update(self, iterable: Iterable[_T]) -> None: ...
53 def discard(self, value: _T) -> None: ...
54 def remove(self, value: _T) -> None: ...
55 def _delete(self, pos: int, idx: int) -> None: ...
56 def _loc(self, pos: int, idx: int) -> int: ...
57 def _pos(self, idx: int) -> int: ...
58 def _build_index(self) -> None: ...
59 def __contains__(self, value: Any) -> bool: ...
60 def __delitem__(self, index: Union[int, slice]) -> None: ...
61 @overload
62 def __getitem__(self, index: int) -> _T: ...
63 @overload
64 def __getitem__(self, index: slice) -> List[_T]: ...
65 @overload
66 def _getitem(self, index: int) -> _T: ...
67 @overload
68 def _getitem(self, index: slice) -> List[_T]: ...
69 @overload
70 def __setitem__(self, index: int, value: _T) -> None: ...
71 @overload
72 def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ...
73 def __iter__(self) -> Iterator[_T]: ...
74 def __reversed__(self) -> Iterator[_T]: ...
75 def __len__(self) -> int: ...
76 def reverse(self) -> None: ...
77 def islice(
78 self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
79 ) -> Iterator[_T]: ...
80 def _islice(
81 self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
82 ) -> Iterator[_T]: ...
83 def irange(
84 self,
85 minimum: Optional[int] = ...,
86 maximum: Optional[int] = ...,
87 inclusive: Tuple[bool, bool] = ...,
88 reverse: bool = ...,
89 ) -> Iterator[_T]: ...
90 def bisect_left(self, value: _T) -> int: ...
91 def bisect_right(self, value: _T) -> int: ...
92 def bisect(self, value: _T) -> int: ...
93 def _bisect_right(self, value: _T) -> int: ...
94 def count(self, value: _T) -> int: ...
95 def copy(self: _SL) -> _SL: ...
96 def __copy__(self: _SL) -> _SL: ...
97 def append(self, value: _T) -> None: ...
98 def extend(self, values: Iterable[_T]) -> None: ...
99 def insert(self, index: int, value: _T) -> None: ...
100 def pop(self, index: int = ...) -> _T: ...
101 def index(
102 self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
103 ) -> int: ...
104 def __add__(self: _SL, other: Iterable[_T]) -> _SL: ...
105 def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ...
106 def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ...
107 def __mul__(self: _SL, num: int) -> _SL: ...
108 def __rmul__(self: _SL, num: int) -> _SL: ...
109 def __imul__(self: _SL, num: int) -> _SL: ...
110 def __eq__(self, other: Any) -> bool: ...
111 def __ne__(self, other: Any) -> bool: ...
112 def __lt__(self, other: Sequence[_T]) -> bool: ...
113 def __gt__(self, other: Sequence[_T]) -> bool: ...
114 def __le__(self, other: Sequence[_T]) -> bool: ...
115 def __ge__(self, other: Sequence[_T]) -> bool: ...
116 def __repr__(self) -> str: ...
117 def _check(self) -> None: ...
118
119 class SortedKeyList(SortedList[_T]):
120 def __init__(
121 self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
122 ) -> None: ...
123 def __new__(
124 cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
125 ) -> SortedKeyList[_T]: ...
126 @property
127 def key(self) -> Callable[[_T], Any]: ...
128 def clear(self) -> None: ...
129 def _clear(self) -> None: ...
130 def add(self, value: _T) -> None: ...
131 def _expand(self, pos: int) -> None: ...
132 def update(self, iterable: Iterable[_T]) -> None: ...
133 def _update(self, iterable: Iterable[_T]) -> None: ...
134 # NB: Must be T to be safely passed to self.func, yet base class imposes Any
135 def __contains__(self, value: _T) -> bool: ... # type: ignore
136 def discard(self, value: _T) -> None: ...
137 def remove(self, value: _T) -> None: ...
138 def _delete(self, pos: int, idx: int) -> None: ...
139 def irange(
140 self,
141 minimum: Optional[int] = ...,
142 maximum: Optional[int] = ...,
143 inclusive: Tuple[bool, bool] = ...,
144 reverse: bool = ...,
145 ): ...
146 def irange_key(
147 self,
148 min_key: Optional[Any] = ...,
149 max_key: Optional[Any] = ...,
150 inclusive: Tuple[bool, bool] = ...,
151 reserve: bool = ...,
152 ): ...
153 def bisect_left(self, value: _T) -> int: ...
154 def bisect_right(self, value: _T) -> int: ...
155 def bisect(self, value: _T) -> int: ...
156 def bisect_key_left(self, key: Any) -> int: ...
157 def _bisect_key_left(self, key: Any) -> int: ...
158 def bisect_key_right(self, key: Any) -> int: ...
159 def _bisect_key_right(self, key: Any) -> int: ...
160 def bisect_key(self, key: Any) -> int: ...
161 def count(self, value: _T) -> int: ...
162 def copy(self: _SKL) -> _SKL: ...
163 def __copy__(self: _SKL) -> _SKL: ...
164 def index(
165 self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
166 ) -> int: ...
167 def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
168 def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
169 def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
170 def __mul__(self: _SKL, num: int) -> _SKL: ...
171 def __rmul__(self: _SKL, num: int) -> _SKL: ...
172 def __imul__(self: _SKL, num: int) -> _SKL: ...
173 def __repr__(self) -> str: ...
174 def _check(self) -> None: ...
175
176 SortedListWithKey = SortedKeyList
1515 """Contains *incomplete* type hints for txredisapi.
1616 """
1717
18 from typing import List, Optional, Union
18 from typing import List, Optional, Union, Type
1919
2020 class RedisProtocol:
2121 def publish(self, channel: str, message: bytes): ...
4141
4242 class SubscriberFactory:
4343 def buildProtocol(self, addr): ...
44
45 class ConnectionHandler: ...
46
47 class RedisFactory:
48 continueTrying: bool
49 handler: RedisProtocol
50 def __init__(
51 self,
52 uuid: str,
53 dbid: Optional[int],
54 poolsize: int,
55 isLazy: bool = False,
56 handler: Type = ConnectionHandler,
57 charset: str = "utf-8",
58 password: Optional[str] = None,
59 replyTimeout: Optional[int] = None,
60 convertNumbers: Optional[int] = True,
61 ): ...
4747 except ImportError:
4848 pass
4949
50 __version__ = "1.21.2"
50 __version__ = "1.22.1"
5151
5252 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
5353 # We import here so that we don't have to install a bunch of deps when
3333 from synapse.events import EventBase
3434 from synapse.logging import opentracing as opentracing
3535 from synapse.types import StateMap, UserID
36 from synapse.util.caches import register_cache
3736 from synapse.util.caches.lrucache import LruCache
3837 from synapse.util.metrics import Measure
3938
6968 self.store = hs.get_datastore()
7069 self.state = hs.get_state_handler()
7170
72 self.token_cache = LruCache(10000)
73 register_cache("cache", "token_cache", self.token_cache)
71 self.token_cache = LruCache(
72 10000, "token_cache"
73 ) # type: LruCache[str, Tuple[str, bool]]
7474
7575 self._auth_blocking = AuthBlocking(self.hs)
7676
154154 class RoomEncryptionAlgorithms:
155155 MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
156156 DEFAULT = MEGOLM_V1_AES_SHA2
157
158
159 class AccountDataTypes:
160 DIRECT = "m.direct"
161 IGNORED_USER_LIST = "m.ignored_user_list"
2727
2828 import synapse
2929 from synapse.app import check_bind_error
30 from synapse.app.phone_stats_home import start_phone_stats_home
3031 from synapse.config.server import ListenerConfig
3132 from synapse.crypto import context_factory
3233 from synapse.logging.context import PreserveLoggingContext
270271 hs.get_datastore().db_pool.start_profiling()
271272 hs.get_pusherpool().start()
272273
274 # Log when we start the shut down process.
275 hs.get_reactor().addSystemEventTrigger(
276 "before", "shutdown", logger.info, "Shutting down..."
277 )
278
273279 setup_sentry(hs)
274280 setup_sdnotify(hs)
281
282 # If background tasks are running on the main process, start collecting the
283 # phone home stats.
284 if hs.config.run_background_tasks:
285 start_phone_stats_home(hs)
275286
276287 # We now freeze all allocated objects in the hopes that (almost)
277288 # everything currently allocated are things that will be used for the
8888 user_id = args.user_id
8989 directory = args.output_directory
9090
91 res = await hs.get_handlers().admin_handler.export_user_data(
91 res = await hs.get_admin_handler().export_user_data(
9292 user_id, FileExfiltrationWriter(user_id, directory=directory)
9393 )
9494 print(res)
207207
208208 # Explicitly disable background processes
209209 config.update_user_directory = False
210 config.run_background_tasks = False
210211 config.start_pushers = False
211212 config.send_federation = False
212213
126126 from synapse.rest.key.v2 import KeyApiV2Resource
127127 from synapse.server import HomeServer, cache_in_self
128128 from synapse.storage.databases.main.censor_events import CensorEventsStore
129 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
129130 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
131 from synapse.storage.databases.main.metrics import ServerMetricsStore
130132 from synapse.storage.databases.main.monthly_active_users import (
131133 MonthlyActiveUsersWorkerStore,
132134 )
133135 from synapse.storage.databases.main.presence import UserPresenceState
134136 from synapse.storage.databases.main.search import SearchWorkerStore
137 from synapse.storage.databases.main.stats import StatsStore
138 from synapse.storage.databases.main.transactions import TransactionWorkerStore
135139 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
136140 from synapse.storage.databases.main.user_directory import UserDirectoryStore
137141 from synapse.types import ReadReceipt
453457 # FIXME(#3714): We need to add UserDirectoryStore as we write directly
454458 # rather than going via the correct worker.
455459 UserDirectoryStore,
460 StatsStore,
456461 UIAuthWorkerStore,
457462 SlavedDeviceInboxStore,
458463 SlavedDeviceStore,
462467 SlavedAccountDataStore,
463468 SlavedPusherStore,
464469 CensorEventsStore,
470 ClientIpWorkerStore,
465471 SlavedEventStore,
466472 SlavedKeyStore,
467473 RoomStore,
475481 SlavedFilteringStore,
476482 MonthlyActiveUsersWorkerStore,
477483 MediaRepositoryStore,
484 ServerMetricsStore,
478485 SearchWorkerStore,
486 TransactionWorkerStore,
479487 BaseSlavedStore,
480488 ):
481489 pass
781789 send_queue.process_rows_for_federation(self.federation_sender, rows)
782790 await self.update_token(token)
783791
784 # We also need to poke the federation sender when new events happen
785 elif stream_name == "events":
786 self.federation_sender.notify_new_events(token)
787
788792 # ... and when new receipts happen
789793 elif stream_name == ReceiptsStream.NAME:
790794 await self._on_new_receipts(rows)
1616
1717 import gc
1818 import logging
19 import math
2019 import os
21 import resource
2220 import sys
2321 from typing import Iterable
24
25 from prometheus_client import Gauge
2622
2723 from twisted.application import service
2824 from twisted.internet import defer, reactor
5955 from synapse.http.site import SynapseSite
6056 from synapse.logging.context import LoggingContext
6157 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
62 from synapse.metrics.background_process_metrics import run_as_background_process
63 from synapse.module_api import ModuleApi
6458 from synapse.python_dependencies import check_requirements
6559 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
6660 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
110104
111105 additional_resources = listener_config.http_options.additional_resources
112106 logger.debug("Configuring additional resources: %r", additional_resources)
113 module_api = ModuleApi(self, self.get_auth_handler())
107 module_api = self.get_module_api()
114108 for path, resmodule in additional_resources.items():
115109 handler_cls, config = load_module(resmodule)
116110 handler = handler_cls(config, module_api)
333327 logger.warning("Unrecognized listener type: %s", listener.type)
334328
335329
336 # Gauges to expose monthly active user control metrics
337 current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
338 current_mau_by_service_gauge = Gauge(
339 "synapse_admin_mau_current_mau_by_service",
340 "Current MAU by service",
341 ["app_service"],
342 )
343 max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
344 registered_reserved_users_mau_gauge = Gauge(
345 "synapse_admin_mau:registered_reserved_users",
346 "Registered users with reserved threepids",
347 )
348
349
350330 def setup(config_options):
351331 """
352332 Args:
388368 except UpgradeDatabaseException as e:
389369 quit_with_error("Failed to upgrade database: %s" % (e,))
390370
391 hs.setup_master()
392
393371 async def do_acme() -> bool:
394372 """
395373 Reprovision an ACME certificate, if it's required.
485463 return self._port.stopListening()
486464
487465
488 # Contains the list of processes we will be monitoring
489 # currently either 0 or 1
490 _stats_process = []
491
492
493 async def phone_stats_home(hs, stats, stats_process=_stats_process):
494 logger.info("Gathering stats for reporting")
495 now = int(hs.get_clock().time())
496 uptime = int(now - hs.start_time)
497 if uptime < 0:
498 uptime = 0
499
500 #
501 # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
502 #
503 old = stats_process[0]
504 new = (now, resource.getrusage(resource.RUSAGE_SELF))
505 stats_process[0] = new
506
507 # Get RSS in bytes
508 stats["memory_rss"] = new[1].ru_maxrss
509
510 # Get CPU time in % of a single core, not % of all cores
511 used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
512 old[1].ru_utime + old[1].ru_stime
513 )
514 if used_cpu_time == 0 or new[0] == old[0]:
515 stats["cpu_average"] = 0
516 else:
517 stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
518
519 #
520 # General statistics
521 #
522
523 stats["homeserver"] = hs.config.server_name
524 stats["server_context"] = hs.config.server_context
525 stats["timestamp"] = now
526 stats["uptime_seconds"] = uptime
527 version = sys.version_info
528 stats["python_version"] = "{}.{}.{}".format(
529 version.major, version.minor, version.micro
530 )
531 stats["total_users"] = await hs.get_datastore().count_all_users()
532
533 total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
534 stats["total_nonbridged_users"] = total_nonbridged_users
535
536 daily_user_type_results = await hs.get_datastore().count_daily_user_type()
537 for name, count in daily_user_type_results.items():
538 stats["daily_user_type_" + name] = count
539
540 room_count = await hs.get_datastore().get_room_count()
541 stats["total_room_count"] = room_count
542
543 stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
544 stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
545 stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
546 stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
547
548 r30_results = await hs.get_datastore().count_r30_users()
549 for name, count in r30_results.items():
550 stats["r30_users_" + name] = count
551
552 daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
553 stats["daily_sent_messages"] = daily_sent_messages
554 stats["cache_factor"] = hs.config.caches.global_factor
555 stats["event_cache_size"] = hs.config.caches.event_cache_size
556
557 #
558 # Database version
559 #
560
561 # This only reports info about the *main* database.
562 stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
563 stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
564
565 logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
566 try:
567 await hs.get_proxied_http_client().put_json(
568 hs.config.report_stats_endpoint, stats
569 )
570 except Exception as e:
571 logger.warning("Error reporting stats: %s", e)
572
573
574466 def run(hs):
575467 PROFILE_SYNAPSE = False
576468 if PROFILE_SYNAPSE:
595487
596488 ThreadPool._worker = profile(ThreadPool._worker)
597489 reactor.run = profile(reactor.run)
598
599 clock = hs.get_clock()
600
601 stats = {}
602
603 def performance_stats_init():
604 _stats_process.clear()
605 _stats_process.append(
606 (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
607 )
608
609 def start_phone_stats_home():
610 return run_as_background_process(
611 "phone_stats_home", phone_stats_home, hs, stats
612 )
613
614 def generate_user_daily_visit_stats():
615 return run_as_background_process(
616 "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
617 )
618
619 # Rather than update on per session basis, batch up the requests.
620 # If you increase the loop period, the accuracy of user_daily_visits
621 # table will decrease
622 clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
623
624 # monthly active user limiting functionality
625 def reap_monthly_active_users():
626 return run_as_background_process(
627 "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
628 )
629
630 clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
631 reap_monthly_active_users()
632
633 async def generate_monthly_active_users():
634 current_mau_count = 0
635 current_mau_count_by_service = {}
636 reserved_users = ()
637 store = hs.get_datastore()
638 if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
639 current_mau_count = await store.get_monthly_active_count()
640 current_mau_count_by_service = (
641 await store.get_monthly_active_count_by_service()
642 )
643 reserved_users = await store.get_registered_reserved_users()
644 current_mau_gauge.set(float(current_mau_count))
645
646 for app_service, count in current_mau_count_by_service.items():
647 current_mau_by_service_gauge.labels(app_service).set(float(count))
648
649 registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
650 max_mau_gauge.set(float(hs.config.max_mau_value))
651
652 def start_generate_monthly_active_users():
653 return run_as_background_process(
654 "generate_monthly_active_users", generate_monthly_active_users
655 )
656
657 start_generate_monthly_active_users()
658 if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
659 clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
660 # End of monthly active user settings
661
662 if hs.config.report_stats:
663 logger.info("Scheduling stats reporting for 3 hour intervals")
664 clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
665
666 # We need to defer this init for the cases that we daemonize
667 # otherwise the process ID we get is that of the non-daemon process
668 clock.call_later(0, performance_stats_init)
669
670 # We wait 5 minutes to send the first set of stats as the server can
671 # be quite busy the first few minutes
672 clock.call_later(5 * 60, start_phone_stats_home)
673490
674491 _base.start_reactor(
675492 "synapse-homeserver",
0 # Copyright 2020 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import logging
14 import math
15 import resource
16 import sys
17
18 from prometheus_client import Gauge
19
20 from synapse.metrics.background_process_metrics import wrap_as_background_process
21
22 logger = logging.getLogger("synapse.app.homeserver")
23
24 # Contains the list of processes we will be monitoring
25 # currently either 0 or 1
26 _stats_process = []
27
28 # Gauges to expose monthly active user control metrics
29 current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
30 current_mau_by_service_gauge = Gauge(
31 "synapse_admin_mau_current_mau_by_service",
32 "Current MAU by service",
33 ["app_service"],
34 )
35 max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
36 registered_reserved_users_mau_gauge = Gauge(
37 "synapse_admin_mau:registered_reserved_users",
38 "Registered users with reserved threepids",
39 )
40
41
42 @wrap_as_background_process("phone_stats_home")
43 async def phone_stats_home(hs, stats, stats_process=_stats_process):
44 logger.info("Gathering stats for reporting")
45 now = int(hs.get_clock().time())
46 uptime = int(now - hs.start_time)
47 if uptime < 0:
48 uptime = 0
49
50 #
51 # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
52 #
53 old = stats_process[0]
54 new = (now, resource.getrusage(resource.RUSAGE_SELF))
55 stats_process[0] = new
56
57 # Get RSS in bytes
58 stats["memory_rss"] = new[1].ru_maxrss
59
60 # Get CPU time in % of a single core, not % of all cores
61 used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
62 old[1].ru_utime + old[1].ru_stime
63 )
64 if used_cpu_time == 0 or new[0] == old[0]:
65 stats["cpu_average"] = 0
66 else:
67 stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
68
69 #
70 # General statistics
71 #
72
73 stats["homeserver"] = hs.config.server_name
74 stats["server_context"] = hs.config.server_context
75 stats["timestamp"] = now
76 stats["uptime_seconds"] = uptime
77 version = sys.version_info
78 stats["python_version"] = "{}.{}.{}".format(
79 version.major, version.minor, version.micro
80 )
81 stats["total_users"] = await hs.get_datastore().count_all_users()
82
83 total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
84 stats["total_nonbridged_users"] = total_nonbridged_users
85
86 daily_user_type_results = await hs.get_datastore().count_daily_user_type()
87 for name, count in daily_user_type_results.items():
88 stats["daily_user_type_" + name] = count
89
90 room_count = await hs.get_datastore().get_room_count()
91 stats["total_room_count"] = room_count
92
93 stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
94 stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
95 stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
96 stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
97
98 r30_results = await hs.get_datastore().count_r30_users()
99 for name, count in r30_results.items():
100 stats["r30_users_" + name] = count
101
102 daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
103 stats["daily_sent_messages"] = daily_sent_messages
104 stats["cache_factor"] = hs.config.caches.global_factor
105 stats["event_cache_size"] = hs.config.caches.event_cache_size
106
107 #
108 # Database version
109 #
110
111 # This only reports info about the *main* database.
112 stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
113 stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
114
115 #
116 # Logging configuration
117 #
118 synapse_logger = logging.getLogger("synapse")
119 log_level = synapse_logger.getEffectiveLevel()
120 stats["log_level"] = logging.getLevelName(log_level)
121
122 logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
123 try:
124 await hs.get_proxied_http_client().put_json(
125 hs.config.report_stats_endpoint, stats
126 )
127 except Exception as e:
128 logger.warning("Error reporting stats: %s", e)
129
130
131 def start_phone_stats_home(hs):
132 """
133 Start the background tasks which report phone home stats.
134 """
135 clock = hs.get_clock()
136
137 stats = {}
138
139 def performance_stats_init():
140 _stats_process.clear()
141 _stats_process.append(
142 (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
143 )
144
145 # Rather than update on per session basis, batch up the requests.
146 # If you increase the loop period, the accuracy of user_daily_visits
147 # table will decrease
148 clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000)
149
150 # monthly active user limiting functionality
151 clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60)
152 hs.get_datastore().reap_monthly_active_users()
153
154 @wrap_as_background_process("generate_monthly_active_users")
155 async def generate_monthly_active_users():
156 current_mau_count = 0
157 current_mau_count_by_service = {}
158 reserved_users = ()
159 store = hs.get_datastore()
160 if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
161 current_mau_count = await store.get_monthly_active_count()
162 current_mau_count_by_service = (
163 await store.get_monthly_active_count_by_service()
164 )
165 reserved_users = await store.get_registered_reserved_users()
166 current_mau_gauge.set(float(current_mau_count))
167
168 for app_service, count in current_mau_count_by_service.items():
169 current_mau_by_service_gauge.labels(app_service).set(float(count))
170
171 registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
172 max_mau_gauge.set(float(hs.config.max_mau_value))
173
174 if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
175 generate_monthly_active_users()
176 clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
177 # End of monthly active user settings
178
179 if hs.config.report_stats:
180 logger.info("Scheduling stats reporting for 3 hour intervals")
181 clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
182
183 # We need to defer this init for the cases that we daemonize
184 # otherwise the process ID we get is that of the non-daemon process
185 clock.call_later(0, performance_stats_init)
186
187 # We wait 5 minutes to send the first set of stats as the server can
188 # be quite busy the first few minutes
189 clock.call_later(5 * 60, phone_stats_home, hs, stats)
1313 # limitations under the License.
1414 import logging
1515 import re
16 from typing import TYPE_CHECKING
16 from typing import TYPE_CHECKING, Iterable, List, Match, Optional
1717
1818 from synapse.api.constants import EventTypes
19 from synapse.appservice.api import ApplicationServiceApi
20 from synapse.types import GroupID, get_domain_from_id
21 from synapse.util.caches.descriptors import cached
19 from synapse.events import EventBase
20 from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
21 from synapse.util.caches.descriptors import _CacheContext, cached
2222
2323 if TYPE_CHECKING:
24 from synapse.appservice.api import ApplicationServiceApi
2425 from synapse.storage.databases.main import DataStore
2526
2627 logger = logging.getLogger(__name__)
2930 class ApplicationServiceState:
3031 DOWN = "down"
3132 UP = "up"
32
33
34 class AppServiceTransaction:
35 """Represents an application service transaction."""
36
37 def __init__(self, service, id, events):
38 self.service = service
39 self.id = id
40 self.events = events
41
42 async def send(self, as_api: ApplicationServiceApi) -> bool:
43 """Sends this transaction using the provided AS API interface.
44
45 Args:
46 as_api: The API to use to send.
47 Returns:
48 True if the transaction was sent.
49 """
50 return await as_api.push_bulk(
51 service=self.service, events=self.events, txn_id=self.id
52 )
53
54 async def complete(self, store: "DataStore") -> None:
55 """Completes this transaction as successful.
56
57 Marks this transaction ID on the application service and removes the
58 transaction contents from the database.
59
60 Args:
61 store: The database store to operate on.
62 """
63 await store.complete_appservice_txn(service=self.service, txn_id=self.id)
6433
6534
6635 class ApplicationService:
9059 protocols=None,
9160 rate_limited=True,
9261 ip_range_whitelist=None,
62 supports_ephemeral=False,
9363 ):
9464 self.token = token
9565 self.url = (
10171 self.namespaces = self._check_namespaces(namespaces)
10272 self.id = id
10373 self.ip_range_whitelist = ip_range_whitelist
74 self.supports_ephemeral = supports_ephemeral
10475
10576 if "|" in self.id:
10677 raise Exception("application service ID cannot contain '|' character")
160131 raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
161132 return namespaces
162133
163 def _matches_regex(self, test_string, namespace_key):
134 def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
164135 for regex_obj in self.namespaces[namespace_key]:
165136 if regex_obj["regex"].match(test_string):
166137 return regex_obj
167138 return None
168139
169 def _is_exclusive(self, ns_key, test_string):
140 def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
170141 regex_obj = self._matches_regex(test_string, ns_key)
171142 if regex_obj:
172143 return regex_obj["exclusive"]
173144 return False
174145
175 async def _matches_user(self, event, store):
146 async def _matches_user(
147 self, event: Optional[EventBase], store: Optional["DataStore"] = None
148 ) -> bool:
176149 if not event:
177150 return False
178151
187160 if not store:
188161 return False
189162
190 does_match = await self._matches_user_in_member_list(event.room_id, store)
163 does_match = await self.matches_user_in_member_list(event.room_id, store)
191164 return does_match
192165
193166 @cached(num_args=1, cache_context=True)
194 async def _matches_user_in_member_list(self, room_id, store, cache_context):
167 async def matches_user_in_member_list(
168 self, room_id: str, store: "DataStore", cache_context: _CacheContext,
169 ) -> bool:
170 """Check if this service is interested a room based upon it's membership
171
172 Args:
173 room_id: The room to check.
174 store: The datastore to query.
175
176 Returns:
177 True if this service would like to know about this room.
178 """
195179 member_list = await store.get_users_in_room(
196180 room_id, on_invalidate=cache_context.invalidate
197181 )
202186 return True
203187 return False
204188
205 def _matches_room_id(self, event):
189 def _matches_room_id(self, event: EventBase) -> bool:
206190 if hasattr(event, "room_id"):
207191 return self.is_interested_in_room(event.room_id)
208192 return False
209193
210 async def _matches_aliases(self, event, store):
194 async def _matches_aliases(
195 self, event: EventBase, store: Optional["DataStore"] = None
196 ) -> bool:
211197 if not store or not event:
212198 return False
213199
217203 return True
218204 return False
219205
220 async def is_interested(self, event, store=None) -> bool:
206 async def is_interested(
207 self, event: EventBase, store: Optional["DataStore"] = None
208 ) -> bool:
221209 """Check if this service is interested in this event.
222210
223211 Args:
224 event(Event): The event to check.
225 store(DataStore)
212 event: The event to check.
213 store: The datastore to query.
214
226215 Returns:
227216 True if this service would like to know about this event.
228217 """
230219 if self._matches_room_id(event):
231220 return True
232221
222 # This will check the namespaces first before
223 # checking the store, so should be run before _matches_aliases
224 if await self._matches_user(event, store):
225 return True
226
227 # This will check the store, so should be run last
233228 if await self._matches_aliases(event, store):
234229 return True
235230
236 if await self._matches_user(event, store):
237 return True
238
239 return False
240
241 def is_interested_in_user(self, user_id):
231 return False
232
233 @cached(num_args=1)
234 async def is_interested_in_presence(
235 self, user_id: UserID, store: "DataStore"
236 ) -> bool:
237 """Check if this service is interested a user's presence
238
239 Args:
240 user_id: The user to check.
241 store: The datastore to query.
242
243 Returns:
244 True if this service would like to know about presence for this user.
245 """
246 # Find all the rooms the sender is in
247 if self.is_interested_in_user(user_id.to_string()):
248 return True
249 room_ids = await store.get_rooms_for_user(user_id.to_string())
250
251 # Then find out if the appservice is interested in any of those rooms
252 for room_id in room_ids:
253 if await self.matches_user_in_member_list(room_id, store):
254 return True
255 return False
256
257 def is_interested_in_user(self, user_id: str) -> bool:
242258 return (
243 self._matches_regex(user_id, ApplicationService.NS_USERS)
259 bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
244260 or user_id == self.sender
245261 )
246262
247 def is_interested_in_alias(self, alias):
263 def is_interested_in_alias(self, alias: str) -> bool:
248264 return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
249265
250 def is_interested_in_room(self, room_id):
266 def is_interested_in_room(self, room_id: str) -> bool:
251267 return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
252268
253 def is_exclusive_user(self, user_id):
269 def is_exclusive_user(self, user_id: str) -> bool:
254270 return (
255271 self._is_exclusive(ApplicationService.NS_USERS, user_id)
256272 or user_id == self.sender
257273 )
258274
259 def is_interested_in_protocol(self, protocol):
275 def is_interested_in_protocol(self, protocol: str) -> bool:
260276 return protocol in self.protocols
261277
262 def is_exclusive_alias(self, alias):
278 def is_exclusive_alias(self, alias: str) -> bool:
263279 return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
264280
265 def is_exclusive_room(self, room_id):
281 def is_exclusive_room(self, room_id: str) -> bool:
266282 return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
267283
268284 def get_exclusive_user_regexes(self):
275291 if regex_obj["exclusive"]
276292 ]
277293
278 def get_groups_for_user(self, user_id):
294 def get_groups_for_user(self, user_id: str) -> Iterable[str]:
279295 """Get the groups that this user is associated with by this AS
280296
281297 Args:
282 user_id (str): The ID of the user.
283
284 Returns:
285 iterable[str]: an iterable that yields group_id strings.
298 user_id: The ID of the user.
299
300 Returns:
301 An iterable that yields group_id strings.
286302 """
287303 return (
288304 regex_obj["group_id"]
290306 if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
291307 )
292308
293 def is_rate_limited(self):
309 def is_rate_limited(self) -> bool:
294310 return self.rate_limited
295311
296312 def __str__(self):
299315 dict_copy["token"] = "<redacted>"
300316 dict_copy["hs_token"] = "<redacted>"
301317 return "ApplicationService: %s" % (dict_copy,)
318
319
320 class AppServiceTransaction:
321 """Represents an application service transaction."""
322
323 def __init__(
324 self,
325 service: ApplicationService,
326 id: int,
327 events: List[EventBase],
328 ephemeral: List[JsonDict],
329 ):
330 self.service = service
331 self.id = id
332 self.events = events
333 self.ephemeral = ephemeral
334
335 async def send(self, as_api: "ApplicationServiceApi") -> bool:
336 """Sends this transaction using the provided AS API interface.
337
338 Args:
339 as_api: The API to use to send.
340 Returns:
341 True if the transaction was sent.
342 """
343 return await as_api.push_bulk(
344 service=self.service,
345 events=self.events,
346 ephemeral=self.ephemeral,
347 txn_id=self.id,
348 )
349
350 async def complete(self, store: "DataStore") -> None:
351 """Completes this transaction as successful.
352
353 Marks this transaction ID on the application service and removes the
354 transaction contents from the database.
355
356 Args:
357 store: The database store to operate on.
358 """
359 await store.complete_appservice_txn(service=self.service, txn_id=self.id)
1313 # limitations under the License.
1414 import logging
1515 import urllib
16 from typing import TYPE_CHECKING, Optional
16 from typing import TYPE_CHECKING, List, Optional, Tuple
1717
1818 from prometheus_client import Counter
1919
2020 from synapse.api.constants import EventTypes, ThirdPartyEntityKind
2121 from synapse.api.errors import CodeMessageException
22 from synapse.events import EventBase
2223 from synapse.events.utils import serialize_event
2324 from synapse.http.client import SimpleHttpClient
2425 from synapse.types import JsonDict, ThirdPartyInstanceID
9293
9394 self.protocol_meta_cache = ResponseCache(
9495 hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
95 )
96 ) # type: ResponseCache[Tuple[str, str]]
9697
9798 async def query_user(self, service, user_id):
9899 if service.url is None:
200201 key = (service.id, protocol)
201202 return await self.protocol_meta_cache.wrap(key, _get)
202203
203 async def push_bulk(self, service, events, txn_id=None):
204 async def push_bulk(
205 self,
206 service: "ApplicationService",
207 events: List[EventBase],
208 ephemeral: List[JsonDict],
209 txn_id: Optional[int] = None,
210 ):
204211 if service.url is None:
205212 return True
206213
210217 logger.warning(
211218 "push_bulk: Missing txn ID sending events to %s", service.url
212219 )
213 txn_id = str(0)
214 txn_id = str(txn_id)
215
216 uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
220 txn_id = 0
221
222 uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
223
224 # Never send ephemeral events to appservices that do not support it
225 if service.supports_ephemeral:
226 body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
227 else:
228 body = {"events": events}
229
217230 try:
218231 await self.put_json(
219 uri=uri,
220 json_body={"events": events},
221 args={"access_token": service.hs_token},
232 uri=uri, json_body=body, args={"access_token": service.hs_token},
222233 )
223234 sent_transactions_counter.labels(service.id).inc()
224235 sent_events_counter.labels(service.id).inc(len(events))
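To make the `supports_ephemeral` branch above concrete: the body PUT to `{service.url}/transactions/{txn_id}` only carries the unstable MSC2409 key when the appservice opted in via `de.sorunome.msc2409.push_ephemeral` in its registration (see the `config/appservice.py` hunk below). The payloads here are an illustrative sketch, not taken from the spec or from Synapse's tests.

```python
events = [{"type": "m.room.message", "room_id": "!a:example.org", "content": {}}]
ephemeral = [{"type": "m.typing", "room_id": "!a:example.org", "content": {}}]

# Appservice that advertises MSC2409 support (service.supports_ephemeral):
body_with_ephemeral = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}

# Appservice without support: ephemeral events are never included.
body_plain = {"events": events}
```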
4848 components.
4949 """
5050 import logging
51
52 from synapse.appservice import ApplicationServiceState
51 from typing import List
52
53 from synapse.appservice import ApplicationService, ApplicationServiceState
54 from synapse.events import EventBase
5355 from synapse.logging.context import run_in_background
5456 from synapse.metrics.background_process_metrics import run_as_background_process
57 from synapse.types import JsonDict
5558
5659 logger = logging.getLogger(__name__)
60
61
62 # Maximum number of events to provide in an AS transaction.
63 MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
64
65 # Maximum number of ephemeral events to provide in an AS transaction.
66 MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
5767
5868
5969 class ApplicationServiceScheduler:
8191 for service in services:
8292 self.txn_ctrl.start_recoverer(service)
8393
84 def submit_event_for_as(self, service, event):
85 self.queuer.enqueue(service, event)
94 def submit_event_for_as(self, service: ApplicationService, event: EventBase):
95 self.queuer.enqueue_event(service, event)
96
97 def submit_ephemeral_events_for_as(
98 self, service: ApplicationService, events: List[JsonDict]
99 ):
100 self.queuer.enqueue_ephemeral(service, events)
86101
87102
88103 class _ServiceQueuer:
95110
96111 def __init__(self, txn_ctrl, clock):
97112 self.queued_events = {} # dict of {service_id: [events]}
113 self.queued_ephemeral = {} # dict of {service_id: [events]}
98114
99115 # the appservices which currently have a transaction in flight
100116 self.requests_in_flight = set()
101117 self.txn_ctrl = txn_ctrl
102118 self.clock = clock
103119
104 def enqueue(self, service, event):
105 self.queued_events.setdefault(service.id, []).append(event)
106
120 def _start_background_request(self, service):
107121 # start a sender for this appservice if we don't already have one
108
109122 if service.id in self.requests_in_flight:
110123 return
111124
113126 "as-sender-%s" % (service.id,), self._send_request, service
114127 )
115128
116 async def _send_request(self, service):
129 def enqueue_event(self, service: ApplicationService, event: EventBase):
130 self.queued_events.setdefault(service.id, []).append(event)
131 self._start_background_request(service)
132
133 def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
134 self.queued_ephemeral.setdefault(service.id, []).extend(events)
135 self._start_background_request(service)
136
137 async def _send_request(self, service: ApplicationService):
117138 # sanity-check: we shouldn't get here if this service already has a sender
118139 # running.
119140 assert service.id not in self.requests_in_flight
121142 self.requests_in_flight.add(service.id)
122143 try:
123144 while True:
124 events = self.queued_events.pop(service.id, [])
125 if not events:
145 all_events = self.queued_events.get(service.id, [])
146 events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
147 del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
148
149 all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
150 ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
151 del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
152
153 if not events and not ephemeral:
126154 return
155
127156 try:
128 await self.txn_ctrl.send(service, events)
157 await self.txn_ctrl.send(service, events, ephemeral)
129158 except Exception:
130159 logger.exception("AS request failed")
131160 finally:
157186 # for UTs
158187 self.RECOVERER_CLASS = _Recoverer
159188
160 async def send(self, service, events):
189 async def send(
190 self,
191 service: ApplicationService,
192 events: List[EventBase],
193 ephemeral: List[JsonDict] = [],
194 ):
161195 try:
162 txn = await self.store.create_appservice_txn(service=service, events=events)
196 txn = await self.store.create_appservice_txn(
197 service=service, events=events, ephemeral=ephemeral
198 )
163199 service_is_up = await self._is_service_up(service)
164200 if service_is_up:
165201 sent = await txn.send(self.as_api)
203239 recoverer.recover()
204240 logger.info("Now %i active recoverers", len(self.recoverers))
205241
206 async def _is_service_up(self, service):
242 async def _is_service_up(self, service: ApplicationService) -> bool:
207243 state = await self.store.get_appservice_state(service)
208244 return state == ApplicationServiceState.UP or state is None
209245
159159 if as_info.get("ip_range_whitelist"):
160160 ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
161161
162 supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
163
162164 return ApplicationService(
163165 token=as_info["as_token"],
164166 hostname=hostname,
167169 hs_token=as_info["hs_token"],
168170 sender=user_id,
169171 id=as_info["id"],
172 supports_ephemeral=supports_ephemeral,
170173 protocols=protocols,
171174 rate_limited=rate_limited,
172175 ip_range_whitelist=ip_range_whitelist,
5555 self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
5656 self.oidc_jwks_uri = oidc_config.get("jwks_uri")
5757 self.oidc_skip_verification = oidc_config.get("skip_verification", False)
58 self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
5859 self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
5960
6061 ump_config = oidc_config.get("user_mapping_provider", {})
158159 #
159160 #skip_verification: true
160161
162 # Whether to fetch the user profile from the userinfo endpoint. Valid
163 # values are: "auto" or "userinfo_endpoint".
164 #
165 # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
166 # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
167 #
168 #user_profile_method: "userinfo_endpoint"
169
161170 # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
162171 # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
163172 #
9999 "media_instance_running_background_jobs",
100100 )
101101
102 self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
102 self.max_upload_size = self.parse_size(config.get("max_upload_size", "50M"))
103103 self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
104104 self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
105105
241241
242242 # The largest allowed upload size in bytes
243243 #
244 #max_upload_size: 10M
244 #max_upload_size: 50M
245245
246246 # Maximum number of pixels that will be thumbnailed
247247 #
3838 # in the list.
3939 DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
4040
41 DEFAULT_ROOM_VERSION = "5"
41 DEFAULT_ROOM_VERSION = "6"
4242
4343 ROOM_COMPLEXITY_TOO_GREAT = (
4444 "Your homeserver is unable to join rooms this large or complex. "
1717 import warnings
1818 from datetime import datetime
1919 from hashlib import sha256
20 from typing import List
20 from typing import List, Optional
2121
2222 from unpaddedbase64 import encode_base64
2323
176176 "use_insecure_ssl_client_just_for_testing_do_not_use"
177177 )
178178
179 self.tls_certificate = None
180 self.tls_private_key = None
179 self.tls_certificate = None # type: Optional[crypto.X509]
180 self.tls_private_key = None # type: Optional[crypto.PKey]
181181
182182 def is_disk_cert_valid(self, allow_self_signed=True):
183183 """
225225 days_remaining = (expires_on - now).days
226226 return days_remaining
227227
228 def read_certificate_from_disk(self, require_cert_and_key):
228 def read_certificate_from_disk(self, require_cert_and_key: bool):
229229 """
230230 Read the certificates and private key from disk.
231231
232232 Args:
233 require_cert_and_key (bool): set to True to throw an error if the certificate
233 require_cert_and_key: set to True to throw an error if the certificate
234234 and key file are not given
235235 """
236236 if require_cert_and_key:
478478 }
479479 )
480480
481 def read_tls_certificate(self):
481 def read_tls_certificate(self) -> crypto.X509:
482482 """Reads the TLS certificate from the configured file, and returns it
483483
484484 Also checks if it is self-signed, and warns if so
485485
486486 Returns:
487 OpenSSL.crypto.X509: the certificate
487 The certificate
488488 """
489489 cert_path = self.tls_certificate_file
490490 logger.info("Loading TLS certificate from %s", cert_path)
503503
504504 return cert
505505
506 def read_tls_private_key(self):
506 def read_tls_private_key(self) -> crypto.PKey:
507507 """Reads the TLS private key from the configured file, and returns it
508508
509509 Returns:
510 OpenSSL.crypto.PKey: the private key
510 The private key
511511 """
512512 private_key_path = self.tls_private_key_file
513513 logger.info("Loading TLS key from %s", private_key_path)
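For orientation, a rough pyOpenSSL equivalent of what the newly annotated helpers return (a standalone sketch with hypothetical paths, not Synapse's implementation):

from OpenSSL import crypto

with open("/path/to/tls.crt", "rb") as f:
    cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())  # crypto.X509
with open("/path/to/tls.key", "rb") as f:
    key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())    # crypto.PKey

# e.g. the self-signed check the docstring warns about:
is_self_signed = cert.get_issuer() == cert.get_subject()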
131131
132132 self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
133133
134 # Whether this worker should run background tasks or not.
135 #
136 # As a note for developers, the background tasks guarded by this should
137 # be able to run on only a single instance (meaning that they don't
138 # depend on any in-memory state of a particular worker).
139 #
140 # No effort is made to ensure only a single instance of these tasks is
141 # running.
142 background_tasks_instance = config.get("run_background_tasks_on") or "master"
143 self.run_background_tasks = (
144 self.worker_name is None and background_tasks_instance == "master"
145 ) or self.worker_name == background_tasks_instance
146
134147 def generate_config_section(self, config_dir_path, server_name, **kwargs):
135148 return """\
136149 ## Workers ##
166179 #stream_writers:
167180 # events: worker1
168181 # typing: worker1
182
183 # The worker that is used to run background tasks (e.g. cleaning up expired
184 # data). If not provided this defaults to the main process.
185 #
186 #run_background_tasks_on: worker1
169187 """
170188
171189 def read_arguments(self, args):
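To make the resolution rule above concrete, a small self-contained sketch (worker names hypothetical) that mirrors the run_background_tasks logic:

def runs_background_tasks(worker_name, run_background_tasks_on) -> bool:
    # The main process runs the tasks unless a worker is explicitly named
    # via `run_background_tasks_on`.
    instance = run_background_tasks_on or "master"
    return (worker_name is None and instance == "master") or worker_name == instance

assert runs_background_tasks(None, None)                # main process, no config
assert runs_background_tasks("worker1", "worker1")      # the designated worker
assert not runs_background_tasks(None, "worker1")       # main process steps aside
assert not runs_background_tasks("worker2", "worker1")  # other workers do nothing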
445445
446446 if room_version_obj.event_format == EventFormatVersions.V1:
447447 redacter_domain = get_domain_from_id(event.event_id)
448 if not isinstance(event.redacts, str):
449 return False
448450 redactee_domain = get_domain_from_id(event.redacts)
449451 if redacter_domain == redactee_domain:
450452 return True
9696
9797
9898 class _EventInternalMetadata:
99 __slots__ = ["_dict"]
99 __slots__ = ["_dict", "stream_ordering"]
100100
101101 def __init__(self, internal_metadata_dict: JsonDict):
102102 # we have to copy the dict, because it turns out that the same dict is
103103 # reused. TODO: fix that
104104 self._dict = dict(internal_metadata_dict)
105
106 # the stream ordering of this event. None, until it has been persisted.
107 self.stream_ordering = None # type: Optional[int]
105108
106109 outlier = DictProperty("outlier") # type: bool
107110 out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
112115 redacted = DictProperty("redacted") # type: bool
113116 txn_id = DictProperty("txn_id") # type: str
114117 token_id = DictProperty("token_id") # type: str
115 stream_ordering = DictProperty("stream_ordering") # type: int
116118
117119 # XXX: These are set by StreamWorkerStore._set_before_and_after.
118120 # I'm pretty sure that these are never persisted to the database, so shouldn't
309311 """
310312 return [e for e, _ in self.auth_events]
311313
314 def freeze(self):
315 """'Freeze' the event dict, so it cannot be modified by accident"""
316
317 # this will be a no-op if the event dict is already frozen.
318 self._dict = freeze(self._dict)
319
312320
313321 class FrozenEvent(EventBase):
314322 format_version = EventFormatVersions.V1 # All events of this type are V1
9696 def is_state(self):
9797 return self._state_key is not None
9898
99 async def build(self, prev_event_ids: List[str]) -> EventBase:
99 async def build(
100 self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
101 ) -> EventBase:
100102 """Transform into a fully signed and hashed event
101103
102104 Args:
103105 prev_event_ids: The event IDs to use as the prev events
106 auth_event_ids: The event IDs to use as the auth events.
107 Should normally be set to None, which will cause them to be calculated
108 based on the room state at the prev_events.
104109
105110 Returns:
106111 The signed and hashed event.
107112 """
108
109 state_ids = await self._state.get_current_state_ids(
110 self.room_id, prev_event_ids
111 )
112 auth_ids = self._auth.compute_auth_events(self, state_ids)
113 if auth_event_ids is None:
114 state_ids = await self._state.get_current_state_ids(
115 self.room_id, prev_event_ids
116 )
117 auth_event_ids = self._auth.compute_auth_events(self, state_ids)
113118
114119 format_version = self.room_version.event_format
115120 if format_version == EventFormatVersions.V1:
116121 # The types of auth/prev events changes between event versions.
117122 auth_events = await self._store.add_event_hashes(
118 auth_ids
123 auth_event_ids
119124 ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
120125 prev_events = await self._store.add_event_hashes(
121126 prev_event_ids
122127 ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
123128 else:
124 auth_events = auth_ids
129 auth_events = auth_event_ids
125130 prev_events = prev_event_ids
126131
127132 old_depth = await self._store.get_max_depth_of(prev_event_ids)
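A short usage sketch of the new parameter (the builder and prev_event_ids are assumed from the surrounding context): passing None keeps the previous behaviour of deriving the auth events from the room state at the prev_events, while a list pins them explicitly.

async def build_with_and_without_pinned_auth(builder, prev_event_ids):
    # Usual path: auth events are computed from the state at the prev_events.
    event = await builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)

    # Explicit path: the caller supplies the auth event IDs itself.
    pinned = await builder.build(
        prev_event_ids=prev_event_ids,
        auth_event_ids=["$abc123:example.org", "$def456:example.org"],
    )
    return event, pinned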
1616 import inspect
1717 from typing import Any, Dict, List, Optional, Tuple
1818
19 from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
19 from synapse.spam_checker_api import RegistrationBehaviour
2020 from synapse.types import Collection
2121
2222 MYPY = False
2323 if MYPY:
24 import synapse.events
2425 import synapse.server
2526
2627
2728 class SpamChecker:
2829 def __init__(self, hs: "synapse.server.HomeServer"):
2930 self.spam_checkers = [] # type: List[Any]
31 api = hs.get_module_api()
3032
3133 for module, config in hs.config.spam_checkers:
3234 # Older spam checkers don't accept the `api` argument, so we
3335 # try and detect support.
3436 spam_args = inspect.getfullargspec(module)
3537 if "api" in spam_args.args:
36 api = SpamCheckerApi(hs)
3738 self.spam_checkers.append(module(config=config, api=api))
3839 else:
3940 self.spam_checkers.append(module(config=config))
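A minimal sketch of a spam-checker module written against the newer convention detected above: because its __init__ accepts an api parameter, Synapse now passes it the shared ModuleApi instance obtained from hs.get_module_api(). The module name and behaviour are hypothetical.

class ExampleSpamChecker:
    def __init__(self, config, api):
        self._config = config
        self._api = api  # a synapse.module_api.ModuleApi instance

    @staticmethod
    def parse_config(config):
        return config

    def check_event_for_spam(self, event) -> bool:
        # Returning False means "not spam". Modules whose __init__ lacks an
        # `api` parameter are still constructed with just `config`, as the
        # fallback branch above shows.
        return False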
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from typing import Callable, Union
16
1517 from synapse.events import EventBase
1618 from synapse.events.snapshot import EventContext
17 from synapse.types import Requester
19 from synapse.types import Requester, StateMap
1820
1921
2022 class ThirdPartyEventRules:
3739
3840 if module is not None:
3941 self.third_party_rules = module(
40 config=config, http_client=hs.get_simple_http_client()
42 config=config, module_api=hs.get_module_api(),
4143 )
4244
4345 async def check_event_allowed(
4446 self, event: EventBase, context: EventContext
45 ) -> bool:
47 ) -> Union[bool, dict]:
4648 """Check if a provided event should be allowed in the given context.
49
50 The module can return:
51 * True: the event is allowed.
52 * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
53 * a dict: replacement event data.
4754
4855 Args:
4956 event: The event to be checked.
5057 context: The context of the event.
5158
5259 Returns:
53 True if the event should be allowed, False if not.
60 The result from the ThirdPartyRules module, as above
5461 """
5562 if self.third_party_rules is None:
5663 return True
5865 prev_state_ids = await context.get_prev_state_ids()
5966
6067 # Retrieve the state events from the database.
61 state_events = {}
62 for key, event_id in prev_state_ids.items():
63 state_events[key] = await self.store.get_event(event_id, allow_none=True)
68 events = await self.store.get_events(prev_state_ids.values())
69 state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
6470
65 ret = await self.third_party_rules.check_event_allowed(event, state_events)
66 return ret
71 # Ensure that the event is frozen, to make sure that the module is not tempted
72 # to try to modify it. Any attempt to modify it at this point will invalidate
73 # the hashes and signatures.
74 event.freeze()
75
76 return await self.third_party_rules.check_event_allowed(event, state_events)
6777
6878 async def on_create_room(
6979 self, requester: Requester, config: dict, is_requester_admin: bool
105115 if self.third_party_rules is None:
106116 return True
107117
118 state_events = await self._get_state_map_for_room(room_id)
119
120 ret = await self.third_party_rules.check_threepid_can_be_invited(
121 medium, address, state_events
122 )
123 return ret
124
125 async def check_visibility_can_be_modified(
126 self, room_id: str, new_visibility: str
127 ) -> bool:
128 """Check if a room is allowed to be published to, or removed from, the public room
129 list.
130
131 Args:
132 room_id: The ID of the room.
133 new_visibility: The new visibility state. Either "public" or "private".
134
135 Returns:
136 True if the room's visibility can be modified, False if not.
137 """
138 if self.third_party_rules is None:
139 return True
140
141 check_func = getattr(
142 self.third_party_rules, "check_visibility_can_be_modified", None
143 )
144 if not check_func or not isinstance(check_func, Callable):
145 return True
146
147 state_events = await self._get_state_map_for_room(room_id)
148
149 return await check_func(room_id, state_events, new_visibility)
150
151 async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
152 """Given a room ID, return the state events of that room.
153
154 Args:
155 room_id: The ID of the room.
156
157 Returns:
158 A dict mapping (event type, state key) to state event.
159 """
108160 state_ids = await self.store.get_filtered_current_state_ids(room_id)
109161 room_state_events = await self.store.get_events(state_ids.values())
110162
112164 for key, event_id in state_ids.items():
113165 state_events[key] = room_state_events[event_id]
114166
115 ret = await self.third_party_rules.check_threepid_can_be_invited(
116 medium, address, state_events
117 )
118 return ret
167 return state_events
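A minimal sketch of a module implementing the interface wrapped above (module name and policies hypothetical). Note that check_event_allowed may now return True, False or a replacement event dict, and that check_visibility_can_be_modified is looked up with getattr(), so modules that predate it keep working:

class ExampleThirdPartyRules:
    def __init__(self, config, module_api):
        self._config = config
        self._api = module_api

    async def check_event_allowed(self, event, state_events):
        # `event` is frozen by the wrapper above; mutating it in place would
        # invalidate its hashes and signatures, so return a dict instead if
        # changes are needed.
        if event.type == "m.room.message" and "forbidden" in event.content.get("body", ""):
            return False  # rejected with M_FORBIDDEN
        return True

    async def on_create_room(self, requester, config, is_requester_admin):
        return True

    async def check_threepid_can_be_invited(self, medium, address, state_events):
        return True

    async def check_visibility_can_be_modified(self, room_id, state_events, new_visibility):
        # For example, never allow rooms to be published to the public list.
        return new_visibility != "public"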
4646
4747 pruned_event = make_event_from_dict(
4848 pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
49 )
50
51 # copy the internal fields
52 pruned_event.internal_metadata.stream_ordering = (
53 event.internal_metadata.stream_ordering
4954 )
5055
5156 # Mark the event as redacted
8282 Args:
8383 event (FrozenEvent): The event to validate.
8484 """
85 if not event.is_state():
86 raise SynapseError(code=400, msg="must be a state event")
87
8588 min_lifetime = event.content.get("min_lifetime")
8689 max_lifetime = event.content.get("max_lifetime")
8790
2121 Callable,
2222 Dict,
2323 List,
24 Match,
2524 Optional,
2625 Tuple,
2726 Union,
9998 super().__init__(hs)
10099
101100 self.auth = hs.get_auth()
102 self.handler = hs.get_handlers().federation_handler
101 self.handler = hs.get_federation_handler()
103102 self.state = hs.get_state_handler()
104103
105104 self.device_handler = hs.get_device_handler()
105
106 # Ensure the following handlers are loaded since they register callbacks
107 # with FederationHandlerRegistry.
108 hs.get_directory_handler()
109
106110 self._federation_ratelimiter = hs.get_federation_ratelimiter()
107111
108112 self._server_linearizer = Linearizer("fed_server")
111115 # We cache results for transaction with the same ID
112116 self._transaction_resp_cache = ResponseCache(
113117 hs, "fed_txn_handler", timeout_ms=30000
114 )
118 ) # type: ResponseCache[Tuple[str, str]]
115119
116120 self.transaction_actions = TransactionActions(self.store)
117121
119123
120124 # We cache responses to state queries, as they take a while and often
121125 # come in waves.
122 self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
126 self._state_resp_cache = ResponseCache(
127 hs, "state_resp", timeout_ms=30000
128 ) # type: ResponseCache[Tuple[str, str]]
123129 self._state_ids_resp_cache = ResponseCache(
124130 hs, "state_ids_resp", timeout_ms=30000
125 )
131 ) # type: ResponseCache[Tuple[str, str]]
126132
127133 self._federation_metrics_domains = (
128134 hs.get_config().federation.federation_metrics_domains
824830 return False
825831
826832
827 def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
833 def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
828834 if not isinstance(acl_entry, str):
829835 logger.warning(
830836 "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
831837 )
832838 return False
833839 regex = glob_to_regex(acl_entry)
834 return regex.match(server_name)
840 return bool(regex.match(server_name))
835841
836842
837843 class FederationHandlerRegistry:
861867 self._edu_type_to_instance = {} # type: Dict[str, str]
862868
863869 def register_edu_handler(
864 self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
870 self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
865871 ):
866872 """Sets the handler callable that will be used to handle an incoming
867873 federation EDU of the given type.
187187 for key in keys[:i]:
188188 del self.edus[key]
189189
190 def notify_new_events(self, current_id):
190 def notify_new_events(self, max_token):
191191 """As per FederationSender"""
192192 # We don't need to replicate this as it gets sent down a different
193193 # stream.
3939 events_processed_counter,
4040 )
4141 from synapse.metrics.background_process_metrics import run_as_background_process
42 from synapse.types import ReadReceipt
42 from synapse.types import ReadReceipt, RoomStreamToken
4343 from synapse.util.metrics import Measure, measure_func
4444
4545 logger = logging.getLogger(__name__)
153153 self._per_destination_queues[destination] = queue
154154 return queue
155155
156 def notify_new_events(self, current_id: int) -> None:
156 def notify_new_events(self, max_token: RoomStreamToken) -> None:
157157 """This gets called when we have some new events we might want to
158158 send out to other servers.
159159 """
160 # We just use the minimum stream ordering and ignore the vector clock
161 # component. This is safe to do as long as we *always* ignore the vector
162 # clock components.
163 current_id = max_token.stream
164
160165 self._last_poked_id = max(current_id, self._last_poked_id)
161166
162167 if self._is_processing:
295300
296301 sent_pdus_destination_dist_total.inc(len(destinations))
297302 sent_pdus_destination_dist_count.inc()
303
304 assert pdu.internal_metadata.stream_ordering
298305
299306 # track the fact that we have a PDU for these destinations,
300307 # to allow us to perform catch-up later on if the remote is unreachable
157157 # yet know if we have anything to catch up (None)
158158 self._pending_pdus.append(pdu)
159159 else:
160 assert pdu.internal_metadata.stream_ordering
160161 self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
161162
162163 self.attempt_new_transaction()
360361 last_successful_stream_ordering = (
361362 final_pdu.internal_metadata.stream_ordering
362363 )
364 assert last_successful_stream_ordering
363365 await self._store.set_destination_last_successful_stream_ordering(
364366 self._destination, last_successful_stream_ordering
365367 )
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 from .admin import AdminHandler
16 from .directory import DirectoryHandler
17 from .federation import FederationHandler
18 from .identity import IdentityHandler
19 from .search import SearchHandler
20
21
22 class Handlers:
23
24 """ Deprecated. A collection of handlers.
25
26 At some point most of the classes whose name ended "Handler" were
27 accessed through this class.
28
29 However this makes it painful to unit test the handlers and to run cut
30 down versions of synapse that only use specific handlers because using a
31 single handler required creating all of the handlers. So some of the
32 handlers have been lifted out of the Handlers object and are now accessed
33 directly through the homeserver object itself.
34
35 Any new handlers should follow the new pattern of being accessed through
36 the homeserver object and should not be added to the Handlers object.
37
38 The remaining handlers should be moved out of the handlers object.
39 """
40
41 def __init__(self, hs):
42 self.federation_handler = FederationHandler(hs)
43 self.directory_handler = DirectoryHandler(hs)
44 self.admin_handler = AdminHandler(hs)
45 self.identity_handler = IdentityHandler(hs)
46 self.search_handler = SearchHandler(hs)
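With the deprecated Handlers collection removed above, callers now fetch each handler directly from the HomeServer object, as seen elsewhere in this changeset (hs being the HomeServer instance):

# Old (deprecated) access via the Handlers collection:
#     federation_handler = hs.get_handlers().federation_handler
#     identity_handler = hs.get_handlers().identity_handler
#
# New accessors used by the rest of this diff:
federation_handler = hs.get_federation_handler()
identity_handler = hs.get_identity_handler()
directory_handler = hs.get_directory_handler()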
1313 # limitations under the License.
1414
1515 import logging
16 from typing import TYPE_CHECKING, Optional
1617
1718 import synapse.state
1819 import synapse.storage
2021 from synapse.api.constants import EventTypes, Membership
2122 from synapse.api.ratelimiting import Ratelimiter
2223 from synapse.types import UserID
24
25 if TYPE_CHECKING:
26 from synapse.app.homeserver import HomeServer
2327
2428 logger = logging.getLogger(__name__)
2529
2933 Common base class for the event handlers.
3034 """
3135
32 def __init__(self, hs):
33 """
34 Args:
35 hs (synapse.server.HomeServer):
36 """
36 def __init__(self, hs: "HomeServer"):
3737 self.store = hs.get_datastore() # type: synapse.storage.DataStore
3838 self.auth = hs.get_auth()
3939 self.notifier = hs.get_notifier()
5555 clock=self.clock,
5656 rate_hz=self.hs.config.rc_admin_redaction.per_second,
5757 burst_count=self.hs.config.rc_admin_redaction.burst_count,
58 )
58 ) # type: Optional[Ratelimiter]
5959 else:
6060 self.admin_redaction_ratelimiter = None
6161
126126 if guest_access != "can_join":
127127 if context:
128128 current_state_ids = await context.get_current_state_ids()
129 current_state = await self.store.get_events(
129 current_state_dict = await self.store.get_events(
130130 list(current_state_ids.values())
131131 )
132 current_state = list(current_state_dict.values())
132133 else:
133 current_state = await self.state_handler.get_current_state(
134 current_state_map = await self.state_handler.get_current_state(
134135 event.room_id
135136 )
136
137 current_state = list(current_state.values())
137 current_state = list(current_state_map.values())
138138
139139 logger.info("maybe_kick_guest_users %r", current_state)
140140 await self.kick_guest_users(current_state)
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import TYPE_CHECKING, List, Tuple
15
16 from synapse.types import JsonDict, UserID
17
18 if TYPE_CHECKING:
19 from synapse.app.homeserver import HomeServer
1420
1521
1622 class AccountDataEventSource:
17 def __init__(self, hs):
23 def __init__(self, hs: "HomeServer"):
1824 self.store = hs.get_datastore()
1925
20 def get_current_key(self, direction="f"):
26 def get_current_key(self, direction: str = "f") -> int:
2127 return self.store.get_max_account_data_stream_id()
2228
23 async def get_new_events(self, user, from_key, **kwargs):
29 async def get_new_events(
30 self, user: UserID, from_key: int, **kwargs
31 ) -> Tuple[List[JsonDict], int]:
2432 user_id = user.to_string()
2533 last_stream_id = from_key
2634
2121
2222 from synapse.api.errors import StoreError
2323 from synapse.logging.context import make_deferred_yieldable
24 from synapse.metrics.background_process_metrics import run_as_background_process
24 from synapse.metrics.background_process_metrics import wrap_as_background_process
2525 from synapse.types import UserID
2626 from synapse.util import stringutils
2727
6262 self._raw_from = email.utils.parseaddr(self._from_string)[1]
6363
6464 # Check the renewal emails to send and send them every 30min.
65 def send_emails():
66 # run as a background process to make sure that the database transactions
67 # have a logcontext to report to
68 return run_as_background_process(
69 "send_renewals", self._send_renewal_emails
70 )
71
72 self.clock.looping_call(send_emails, 30 * 60 * 1000)
73
65 if hs.config.run_background_tasks:
66 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
67
68 @wrap_as_background_process("send_renewals")
7469 async def _send_renewal_emails(self):
7570 """Gets the list of users whose account is expiring in the amount of time
7671 configured in the ``renew_at`` parameter from the ``account_validity``
1313 # limitations under the License.
1414
1515 import logging
16 from typing import Dict, List, Optional
1617
1718 from prometheus_client import Counter
1819
2021
2122 import synapse
2223 from synapse.api.constants import EventTypes
24 from synapse.appservice import ApplicationService
25 from synapse.events import EventBase
26 from synapse.handlers.presence import format_user_presence_state
2327 from synapse.logging.context import make_deferred_yieldable, run_in_background
2428 from synapse.metrics import (
2529 event_processing_loop_counter,
2630 event_processing_loop_room_count,
2731 )
2832 from synapse.metrics.background_process_metrics import run_as_background_process
33 from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
2934 from synapse.util.metrics import Measure
3035
3136 logger = logging.getLogger(__name__)
4247 self.started_scheduler = False
4348 self.clock = hs.get_clock()
4449 self.notify_appservices = hs.config.notify_appservices
50 self.event_sources = hs.get_event_sources()
4551
4652 self.current_max = 0
4753 self.is_processing = False
4854
49 async def notify_interested_services(self, current_id):
55 async def notify_interested_services(self, max_token: RoomStreamToken):
5056 """Notifies (pushes) all application services interested in this event.
5157
5258 Pushing is done asynchronously, so this method won't block for any
5359 prolonged length of time.
54
55 Args:
56 current_id(int): The current maximum ID.
5760 """
61 # We just use the minimum stream ordering and ignore the vector clock
62 # component. This is safe to do as long as we *always* ignore the vector
63 # clock components.
64 current_id = max_token.stream
65
5866 services = self.store.get_app_services()
5967 if not services or not self.notify_appservices:
6068 return
7886 if not events:
7987 break
8088
81 events_by_room = {}
89 events_by_room = {} # type: Dict[str, List[EventBase]]
8290 for event in events:
8391 events_by_room.setdefault(event.room_id, []).append(event)
8492
156164 ).set(ts)
157165 finally:
158166 self.is_processing = False
167
168 async def notify_interested_services_ephemeral(
169 self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [],
170 ):
171 """This is called by the notifier in the background
172 when a ephemeral event handled by the homeserver.
173
174 This will determine which appservices
175 are interested in the event, and submit them.
176
177 Events will only be pushed to appservices
178 that have opted into ephemeral events
179
180 Args:
181 stream_key: The stream the event came from.
182 new_token: The latest stream token
183 users: The user(s) involved with the event.
184 """
185 services = [
186 service
187 for service in self.store.get_app_services()
188 if service.supports_ephemeral
189 ]
190 if not services or not self.notify_appservices:
191 return
192 logger.info("Checking interested services for %s" % (stream_key))
193 with Measure(self.clock, "notify_interested_services_ephemeral"):
194 for service in services:
195 # Only handle typing if we have the latest token
196 if stream_key == "typing_key" and new_token is not None:
197 events = await self._handle_typing(service, new_token)
198 if events:
199 self.scheduler.submit_ephemeral_events_for_as(service, events)
200 # We don't persist the token for typing_key for performance reasons
201 elif stream_key == "receipt_key":
202 events = await self._handle_receipts(service)
203 if events:
204 self.scheduler.submit_ephemeral_events_for_as(service, events)
205 await self.store.set_type_stream_id_for_appservice(
206 service, "read_receipt", new_token
207 )
208 elif stream_key == "presence_key":
209 events = await self._handle_presence(service, users)
210 if events:
211 self.scheduler.submit_ephemeral_events_for_as(service, events)
212 await self.store.set_type_stream_id_for_appservice(
213 service, "presence", new_token
214 )
215
216 async def _handle_typing(self, service: ApplicationService, new_token: int):
217 typing_source = self.event_sources.sources["typing"]
218 # Get the typing events from just before current
219 typing, _ = await typing_source.get_new_events_as(
220 service=service,
221 # For performance reasons, we don't persist the previous
222 # token in the DB and instead fetch the latest typing information
223 # for appservices.
224 from_key=new_token - 1,
225 )
226 return typing
227
228 async def _handle_receipts(self, service: ApplicationService):
229 from_key = await self.store.get_type_stream_id_for_appservice(
230 service, "read_receipt"
231 )
232 receipts_source = self.event_sources.sources["receipt"]
233 receipts, _ = await receipts_source.get_new_events_as(
234 service=service, from_key=from_key
235 )
236 return receipts
237
238 async def _handle_presence(
239 self, service: ApplicationService, users: Collection[UserID]
240 ) -> List[JsonDict]:
241 events = [] # type: List[JsonDict]
242 presence_source = self.event_sources.sources["presence"]
243 from_key = await self.store.get_type_stream_id_for_appservice(
244 service, "presence"
245 )
246 for user in users:
247 interested = await service.is_interested_in_presence(user, self.store)
248 if not interested:
249 continue
250 presence_events, _ = await presence_source.get_new_events(
251 user=user, service=service, from_key=from_key,
252 )
253 time_now = self.clock.time_msec()
254 events.extend(
255 {
256 "type": "m.presence",
257 "sender": event.user_id,
258 "content": format_user_presence_state(
259 event, time_now, include_user_id=False
260 ),
261 }
262 for event in presence_events
263 )
264
265 return events
159266
160267 async def query_user_exists(self, user_id):
161268 """Check if any application service knows this user_id exists.
219326
220327 async def get_3pe_protocols(self, only_protocol=None):
221328 services = self.store.get_app_services()
222 protocols = {}
329 protocols = {} # type: Dict[str, List[JsonDict]]
223330
224331 # Collect up all the individual protocol responses out of the ASes
225332 for s in services:
163163
164164 self.bcrypt_rounds = hs.config.bcrypt_rounds
165165
166 # we can't use hs.get_module_api() here, because doing so would create an
167 # import loop.
168 #
169 # TODO: refactor this class to separate the lower-level stuff that
170 # ModuleApi can use from the higher-level stuff that uses ModuleApi, as a
171 # better way to break the loop
166172 account_handler = ModuleApi(hs, self)
173
167174 self.password_providers = [
168175 module(config=config, account_handler=account_handler)
169176 for module, config in hs.config.password_providers
211218 self._clock = self.hs.get_clock()
212219
213220 # Expire old UI auth sessions after a period of time.
214 if hs.config.worker_app is None:
221 if hs.config.run_background_tasks:
215222 self._clock.looping_call(
216223 run_as_background_process,
217224 5 * 60 * 1000,
10721079 if medium == "email":
10731080 address = canonicalise_email(address)
10741081
1075 identity_handler = self.hs.get_handlers().identity_handler
1082 identity_handler = self.hs.get_identity_handler()
10761083 result = await identity_handler.try_unbind_threepid(
10771084 user_id, {"medium": medium, "address": address, "id_server": id_server}
10781085 )
11141121 Whether self.hash(password) == stored_hash.
11151122 """
11161123
1117 def _do_validate_hash():
1124 def _do_validate_hash(checked_hash: bytes):
11181125 # Normalise the Unicode in the password
11191126 pw = unicodedata.normalize("NFKC", password)
11201127
11211128 return bcrypt.checkpw(
11221129 pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
1123 stored_hash,
1130 checked_hash,
11241131 )
11251132
11261133 if stored_hash:
11271134 if not isinstance(stored_hash, bytes):
11281135 stored_hash = stored_hash.encode("ascii")
11291136
1130 return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
1137 return await defer_to_thread(
1138 self.hs.get_reactor(), _do_validate_hash, stored_hash
1139 )
11311140 else:
11321141 return False
11331142
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515 import logging
16 from typing import Optional
16 from typing import TYPE_CHECKING, Optional
1717
1818 from synapse.api.errors import SynapseError
1919 from synapse.metrics.background_process_metrics import run_as_background_process
2121
2222 from ._base import BaseHandler
2323
24 if TYPE_CHECKING:
25 from synapse.app.homeserver import HomeServer
26
2427 logger = logging.getLogger(__name__)
2528
2629
2730 class DeactivateAccountHandler(BaseHandler):
2831 """Handler which deals with deactivating user accounts."""
2932
30 def __init__(self, hs):
33 def __init__(self, hs: "HomeServer"):
3134 super().__init__(hs)
3235 self.hs = hs
3336 self._auth_handler = hs.get_auth_handler()
3437 self._device_handler = hs.get_device_handler()
3538 self._room_member_handler = hs.get_room_member_handler()
36 self._identity_handler = hs.get_handlers().identity_handler
39 self._identity_handler = hs.get_identity_handler()
3740 self.user_directory_handler = hs.get_user_directory_handler()
3841
3942 # Flag that indicates whether the process to part users from rooms is running
4144
4245 # Start the user parter loop so it can resume parting users from rooms where
4346 # it left off (if it has work left to do).
44 if hs.config.worker_app is None:
47 if hs.config.run_background_tasks:
4548 hs.get_reactor().callWhenRunning(self._start_user_parting)
4649
4750 self._account_validity_enabled = hs.config.account_validity.enabled
136139
137140 return identity_server_supports_unbinding
138141
139 async def _reject_pending_invites_for_user(self, user_id: str):
142 async def _reject_pending_invites_for_user(self, user_id: str) -> None:
140143 """Reject pending invites addressed to a given user ID.
141144
142145 Args:
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
22 # Copyright 2019 New Vector Ltd
3 # Copyright 2019 The Matrix.org Foundation C.I.C.
3 # Copyright 2019,2020 The Matrix.org Foundation C.I.C.
44 #
55 # Licensed under the Apache License, Version 2.0 (the "License");
66 # you may not use this file except in compliance with the License.
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616 import logging
17 from typing import Any, Dict, List, Optional
17 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
1818
1919 from synapse.api import errors
2020 from synapse.api.constants import EventTypes
2828 from synapse.logging.opentracing import log_kv, set_tag, trace
2929 from synapse.metrics.background_process_metrics import run_as_background_process
3030 from synapse.types import (
31 Collection,
32 JsonDict,
3133 StreamToken,
34 UserID,
3235 get_domain_from_id,
3336 get_verify_key_from_cross_signing_key,
3437 )
4043
4144 from ._base import BaseHandler
4245
46 if TYPE_CHECKING:
47 from synapse.app.homeserver import HomeServer
48
4349 logger = logging.getLogger(__name__)
4450
4551 MAX_DEVICE_DISPLAY_NAME_LEN = 100
4652
4753
4854 class DeviceWorkerHandler(BaseHandler):
49 def __init__(self, hs):
55 def __init__(self, hs: "HomeServer"):
5056 super().__init__(hs)
5157
5258 self.hs = hs
104110
105111 @trace
106112 @measure_func("device.get_user_ids_changed")
107 async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
113 async def get_user_ids_changed(
114 self, user_id: str, from_token: StreamToken
115 ) -> JsonDict:
108116 """Get list of users that have had the devices updated, or have newly
109117 joined a room, that `user_id` may be interested in.
110118 """
220228 possibly_joined = possibly_changed & users_who_share_room
221229 possibly_left = (possibly_changed | possibly_left) - users_who_share_room
222230 else:
223 possibly_joined = []
224 possibly_left = []
231 possibly_joined = set()
232 possibly_left = set()
225233
226234 result = {"changed": list(possibly_joined), "left": list(possibly_left)}
227235
229237
230238 return result
231239
232 async def on_federation_query_user_devices(self, user_id):
240 async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
233241 stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
234242 user_id
235243 )
248256
249257
250258 class DeviceHandler(DeviceWorkerHandler):
251 def __init__(self, hs):
259 def __init__(self, hs: "HomeServer"):
252260 super().__init__(hs)
253261
254262 self.federation_sender = hs.get_federation_sender()
263271
264272 hs.get_distributor().observe("user_left_room", self.user_left_room)
265273
266 def _check_device_name_length(self, name: str):
274 def _check_device_name_length(self, name: Optional[str]):
267275 """
268276 Checks whether a device name is longer than the maximum allowed length.
269277
282290 )
283291
284292 async def check_device_registered(
285 self, user_id, device_id, initial_device_display_name=None
286 ):
293 self,
294 user_id: str,
295 device_id: Optional[str],
296 initial_device_display_name: Optional[str] = None,
297 ) -> str:
287298 """
288299 If the given device has not been registered, register it with the
289300 supplied display name.
291302 If no device_id is supplied, we make one up.
292303
293304 Args:
294 user_id (str): @user:id
295 device_id (str | None): device id supplied by client
296 initial_device_display_name (str | None): device display name from
297 client
305 user_id: @user:id
306 device_id: device id supplied by client
307 initial_device_display_name: device display name from client
298308 Returns:
299 str: device id (generated if none was supplied)
309 device id (generated if none was supplied)
300310 """
301311
302312 self._check_device_name_length(initial_device_display_name)
315325 # times in case of a clash.
316326 attempts = 0
317327 while attempts < 5:
318 device_id = stringutils.random_string(10).upper()
328 new_device_id = stringutils.random_string(10).upper()
319329 new_device = await self.store.store_device(
320330 user_id=user_id,
321 device_id=device_id,
331 device_id=new_device_id,
322332 initial_device_display_name=initial_device_display_name,
323333 )
324334 if new_device:
325 await self.notify_device_update(user_id, [device_id])
326 return device_id
335 await self.notify_device_update(user_id, [new_device_id])
336 return new_device_id
327337 attempts += 1
328338
329339 raise errors.StoreError(500, "Couldn't generate a device ID.")
432442
433443 @trace
434444 @measure_func("notify_device_update")
435 async def notify_device_update(self, user_id, device_ids):
445 async def notify_device_update(
446 self, user_id: str, device_ids: Collection[str]
447 ) -> None:
436448 """Notify that a user's device(s) has changed. Pokes the notifier, and
437449 remote servers if the user is local.
438450 """
444456 user_id
445457 )
446458
447 hosts = set()
459 hosts = set() # type: Set[str]
448460 if self.hs.is_mine_id(user_id):
449461 hosts.update(get_domain_from_id(u) for u in users_who_share_room)
450462 hosts.discard(self.server_name)
496508
497509 self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
498510
499 async def user_left_room(self, user, room_id):
511 async def user_left_room(self, user: UserID, room_id: str) -> None:
500512 user_id = user.to_string()
501513 room_ids = await self.store.get_rooms_for_user(user_id)
502514 if not room_ids:
504516 # receive device updates. Mark this in DB.
505517 await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
506518
507
508 def _update_device_from_client_ips(device, client_ips):
519 async def store_dehydrated_device(
520 self,
521 user_id: str,
522 device_data: JsonDict,
523 initial_device_display_name: Optional[str] = None,
524 ) -> str:
525 """Store a dehydrated device for a user. If the user had a previous
526 dehydrated device, it is removed.
527
528 Args:
529 user_id: the user that we are storing the device for
530 device_data: the dehydrated device information
531 initial_device_display_name: The display name to use for the device
532 Returns:
533 device id of the dehydrated device
534 """
535 device_id = await self.check_device_registered(
536 user_id, None, initial_device_display_name,
537 )
538 old_device_id = await self.store.store_dehydrated_device(
539 user_id, device_id, device_data
540 )
541 if old_device_id is not None:
542 await self.delete_device(user_id, old_device_id)
543 return device_id
544
545 async def get_dehydrated_device(
546 self, user_id: str
547 ) -> Optional[Tuple[str, JsonDict]]:
548 """Retrieve the information for a dehydrated device.
549
550 Args:
551 user_id: the user whose dehydrated device we are looking for
552 Returns:
553 a tuple whose first item is the device ID, and the second item is
554 the dehydrated device information
555 """
556 return await self.store.get_dehydrated_device(user_id)
557
558 async def rehydrate_device(
559 self, user_id: str, access_token: str, device_id: str
560 ) -> dict:
561 """Process a rehydration request from the user.
562
563 Args:
564 user_id: the user who is rehydrating the device
565 access_token: the access token used for the request
566 device_id: the ID of the device that will be rehydrated
567 Returns:
568 a dict containing {"success": True}
569 """
570 success = await self.store.remove_dehydrated_device(user_id, device_id)
571
572 if not success:
573 raise errors.NotFoundError()
574
575 # If the dehydrated device was successfully deleted (the device ID
576 # matched the stored dehydrated device), then modify the access
577 # token to use the dehydrated device's ID and copy the old device
578 # display name to the dehydrated device, and destroy the old device
579 # ID
580 old_device_id = await self.store.set_device_for_access_token(
581 access_token, device_id
582 )
583 old_device = await self.store.get_device(user_id, old_device_id)
584 await self.store.update_device(user_id, device_id, old_device["display_name"])
585 # can't call self.delete_device because that will clobber the
586 # access token so call the storage layer directly
587 await self.store.delete_device(user_id, old_device_id)
588 await self.store.delete_e2e_keys_by_device(
589 user_id=user_id, device_id=old_device_id
590 )
591
592 # tell everyone that the old device is gone and that the dehydrated
593 # device has a new display name
594 await self.notify_device_update(user_id, [old_device_id, device_id])
595
596 return {"success": True}
597
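Taken together, the three methods added above make up the dehydrated-device flow (MSC2697). A rough usage sketch; the user ID, access token and device_data payload are hypothetical, and device_data is treated as an opaque blob chosen by the client:

async def dehydration_round_trip(hs, access_token):
    handler = hs.get_device_handler()

    # 1. A client uploads a "dehydrated" device to be claimed by a later login.
    await handler.store_dehydrated_device(
        user_id="@alice:example.org",
        device_data={"algorithm": "<dehydration algorithm>", "account": "<opaque client data>"},
        initial_device_display_name="Dehydrated device",
    )

    # 2. On a later login the client asks whether a dehydrated device exists...
    found = await handler.get_dehydrated_device("@alice:example.org")

    # 3. ...and, if so, rehydrates it: the new access token is pointed at the
    #    dehydrated device and the old device is deleted and announced as gone.
    if found is not None:
        device_id, device_data = found
        result = await handler.rehydrate_device("@alice:example.org", access_token, device_id)
        assert result == {"success": True}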
598
599 def _update_device_from_client_ips(
600 device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
601 ) -> None:
509602 ip = client_ips.get((device["user_id"], device["device_id"]), {})
510603 device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
511604
513606 class DeviceListUpdater:
514607 "Handles incoming device list updates from federation and updates the DB"
515608
516 def __init__(self, hs, device_handler):
609 def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
517610 self.store = hs.get_datastore()
518611 self.federation = hs.get_federation_client()
519612 self.clock = hs.get_clock()
522615 self._remote_edu_linearizer = Linearizer(name="remote_device_list")
523616
524617 # user_id -> list of updates waiting to be handled.
525 self._pending_updates = {}
618 self._pending_updates = (
619 {}
620 ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
526621
527622 # Recently seen stream ids. We don't bother keeping these in the DB,
528623 # but they're useful to have them about to reduce the number of spurious
545640 )
546641
547642 @trace
548 async def incoming_device_list_update(self, origin, edu_content):
643 async def incoming_device_list_update(
644 self, origin: str, edu_content: JsonDict
645 ) -> None:
549646 """Called on incoming device list update from federation. Responsible
550647 for parsing the EDU and adding to pending updates list.
551648 """
606703 await self._handle_device_updates(user_id)
607704
608705 @measure_func("_incoming_device_list_update")
609 async def _handle_device_updates(self, user_id):
706 async def _handle_device_updates(self, user_id: str) -> None:
610707 "Actually handle pending updates."
611708
612709 with (await self._remote_edu_linearizer.queue(user_id)):
654751 stream_id for _, stream_id, _, _ in pending_updates
655752 )
656753
657 async def _need_to_do_resync(self, user_id, updates):
754 async def _need_to_do_resync(
755 self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
756 ) -> bool:
658757 """Given a list of updates for a user figure out if we need to do a full
659758 resync, or whether we have enough data that we can just apply the delta.
660759 """
685784 return False
686785
687786 @trace
688 async def _maybe_retry_device_resync(self):
787 async def _maybe_retry_device_resync(self) -> None:
689788 """Retry to resync device lists that are out of sync, except if another retry is
690789 in progress.
691790 """
728827
729828 async def user_device_resync(
730829 self, user_id: str, mark_failed_as_stale: bool = True
731 ) -> Optional[dict]:
830 ) -> Optional[JsonDict]:
732831 """Fetches all devices for a user and updates the device cache with them.
733832
734833 Args:
752851 # it later.
753852 await self.store.mark_remote_user_device_cache_as_stale(user_id)
754853
755 return
854 return None
756855 except (RequestSendFailed, HttpResponseException) as e:
757856 logger.warning(
758857 "Failed to handle device list update for %s: %s", user_id, e,
769868 # next time we get a device list update for this user_id.
770869 # This makes it more likely that the device lists will
771870 # eventually become consistent.
772 return
871 return None
773872 except FederationDeniedError as e:
774873 set_tag("error", True)
775874 log_kv({"reason": "FederationDeniedError"})
776875 logger.info(e)
777 return
876 return None
778877 except Exception as e:
779878 set_tag("error", True)
780879 log_kv(
787886 # it later.
788887 await self.store.mark_remote_user_device_cache_as_stale(user_id)
789888
790 return
889 return None
791890 log_kv({"result": result})
792891 stream_id = result["stream_id"]
793892 devices = result["devices"]
848947 user_id: str,
849948 master_key: Optional[Dict[str, Any]],
850949 self_signing_key: Optional[Dict[str, Any]],
851 ) -> list:
950 ) -> List[str]:
852951 """Process the given new master and self-signing key for the given remote user.
853952
854953 Args:
1313 # limitations under the License.
1414
1515 import logging
16 from typing import Any, Dict
16 from typing import TYPE_CHECKING, Any, Dict
1717
1818 from synapse.api.errors import SynapseError
1919 from synapse.logging.context import run_in_background
2323 set_tag,
2424 start_active_span,
2525 )
26 from synapse.types import UserID, get_domain_from_id
26 from synapse.types import JsonDict, UserID, get_domain_from_id
2727 from synapse.util import json_encoder
2828 from synapse.util.stringutils import random_string
2929
30 if TYPE_CHECKING:
31 from synapse.app.homeserver import HomeServer
32
33
3034 logger = logging.getLogger(__name__)
3135
3236
3337 class DeviceMessageHandler:
34 def __init__(self, hs):
38 def __init__(self, hs: "HomeServer"):
3539 """
3640 Args:
37 hs (synapse.server.HomeServer): server
41 hs: server
3842 """
3943 self.store = hs.get_datastore()
4044 self.notifier = hs.get_notifier()
4751
4852 self._device_list_updater = hs.get_device_handler().device_list_updater
4953
50 async def on_direct_to_device_edu(self, origin, content):
54 async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
5155 local_messages = {}
5256 sender_user_id = content["sender"]
5357 if origin != get_domain_from_id(sender_user_id):
9498 message_type: str,
9599 sender_user_id: str,
96100 by_device: Dict[str, Dict[str, Any]],
97 ):
101 ) -> None:
98102 """Checks inbound device messages for unknown remote devices, and if
99103 found marks the remote cache for the user as stale.
100104 """
137141 self._device_list_updater.user_device_resync, sender_user_id
138142 )
139143
140 async def send_device_message(self, sender_user_id, message_type, messages):
144 async def send_device_message(
145 self,
146 sender_user_id: str,
147 message_type: str,
148 messages: Dict[str, Dict[str, JsonDict]],
149 ) -> None:
141150 set_tag("number_of_messages", len(messages))
142151 set_tag("sender", sender_user_id)
143152 local_messages = {}
144 remote_messages = {}
153 remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
145154 for user_id, by_device in messages.items():
146155 # we use UserID.from_string to catch invalid user ids
147156 if self.is_mine(UserID.from_string(user_id)):
4545 self.config = hs.config
4646 self.enable_room_list_search = hs.config.enable_room_list_search
4747 self.require_membership = hs.config.require_membership_for_aliases
48 self.third_party_event_rules = hs.get_third_party_event_rules()
4849
4950 self.federation = hs.get_federation_client()
5051 hs.get_federation_registry().register_query_handler(
382383 """
383384 creator = await self.store.get_room_alias_creator(alias.to_string())
384385
385 if creator is not None and creator == user_id:
386 if creator == user_id:
386387 return True
387388
388389 # Resolve the alias to the corresponding room.
453454 # per alias creation rule?
454455 raise SynapseError(403, "Not allowed to publish room")
455456
457 # Check if publishing is blocked by a third party module
458 allowed_by_third_party_rules = await (
459 self.third_party_event_rules.check_visibility_can_be_modified(
460 room_id, visibility
461 )
462 )
463 if not allowed_by_third_party_rules:
464 raise SynapseError(403, "Not allowed to publish room")
465
456466 await self.store.set_room_is_public(room_id, making_public)
457467
458468 async def edit_published_appservice_room_list(
494494 else:
495495 log_kv(
496496 {"message": "Did not update one_time_keys", "reason": "no keys given"}
497 )
498 fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
499 if fallback_keys and isinstance(fallback_keys, dict):
500 log_kv(
501 {
502 "message": "Updating fallback_keys for device.",
503 "user_id": user_id,
504 "device_id": device_id,
505 }
506 )
507 await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
508 elif fallback_keys:
509 log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
510 else:
511 log_kv(
512 {"message": "Did not update fallback_keys", "reason": "no keys given"}
497513 )
498514
499515 # the device should have been registered already, but it may have been
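For reference, a sketch of a key-upload body that exercises the new branch. The unstable field name comes from the code above; the key contents are hypothetical and assume the same shape as signed one-time keys (MSC2732):

keys = {
    "org.matrix.msc2732.fallback_keys": {
        "signed_curve25519:AAAAAA": {
            "key": "<base64 curve25519 public key>",
            "fallback": True,
            "signatures": {
                "@alice:example.org": {"ed25519:DEVICEID": "<base64 signature>"},
            },
        },
    },
}

fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
# A non-empty dict is persisted via store.set_e2e_fallback_keys(user_id, device_id, fallback_keys).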
15061506 event, context = await self.event_creation_handler.create_new_client_event(
15071507 builder=builder
15081508 )
1509 except AuthError as e:
1509 except SynapseError as e:
15101510 logger.warning("Failed to create join to %s because %s", room_id, e)
1511 raise e
1512
1513 event_allowed = await self.third_party_event_rules.check_event_allowed(
1514 event, context
1515 )
1516 if not event_allowed:
1517 logger.info("Creation of join %s forbidden by third-party rules", event)
1518 raise SynapseError(
1519 403, "This event is not allowed in this context", Codes.FORBIDDEN
1520 )
1511 raise
15211512
15221513 # The remote hasn't signed it yet, obviously. We'll do the full checks
15231514 # when we get the event back in `on_send_join_request`
15661557
15671558 context = await self._handle_new_event(origin, event)
15681559
1569 event_allowed = await self.third_party_event_rules.check_event_allowed(
1570 event, context
1571 )
1572 if not event_allowed:
1573 logger.info("Sending of join %s forbidden by third-party rules", event)
1574 raise SynapseError(
1575 403, "This event is not allowed in this context", Codes.FORBIDDEN
1576 )
1577
15781560 logger.debug(
15791561 "on_send_join_request: After _handle_new_event: %s, sigs: %s",
15801562 event.event_id,
17471729 builder=builder
17481730 )
17491731
1750 event_allowed = await self.third_party_event_rules.check_event_allowed(
1751 event, context
1752 )
1753 if not event_allowed:
1754 logger.warning("Creation of leave %s forbidden by third-party rules", event)
1755 raise SynapseError(
1756 403, "This event is not allowed in this context", Codes.FORBIDDEN
1757 )
1758
17591732 try:
17601733 # The remote hasn't signed it yet, obviously. We'll do the full checks
17611734 # when we get the event back in `on_send_leave_request`
17881761
17891762 event.internal_metadata.outlier = False
17901763
1791 context = await self._handle_new_event(origin, event)
1792
1793 event_allowed = await self.third_party_event_rules.check_event_allowed(
1794 event, context
1795 )
1796 if not event_allowed:
1797 logger.info("Sending of leave %s forbidden by third-party rules", event)
1798 raise SynapseError(
1799 403, "This event is not allowed in this context", Codes.FORBIDDEN
1800 )
1764 await self._handle_new_event(origin, event)
18011765
18021766 logger.debug(
18031767 "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
26932657 builder=builder
26942658 )
26952659
2696 event_allowed = await self.third_party_event_rules.check_event_allowed(
2697 event, context
2698 )
2699 if not event_allowed:
2700 logger.info(
2701 "Creation of threepid invite %s forbidden by third-party rules",
2702 event,
2703 )
2704 raise SynapseError(
2705 403, "This event is not allowed in this context", Codes.FORBIDDEN
2706 )
2707
27082660 event, context = await self.add_display_name_to_third_party_invite(
27092661 room_version, event_dict, event, context
27102662 )
27552707 event, context = await self.event_creation_handler.create_new_client_event(
27562708 builder=builder
27572709 )
2758
2759 event_allowed = await self.third_party_event_rules.check_event_allowed(
2760 event, context
2761 )
2762 if not event_allowed:
2763 logger.warning(
2764 "Exchange of threepid invite %s forbidden by third-party rules", event
2765 )
2766 raise SynapseError(
2767 403, "This event is not allowed in this context", Codes.FORBIDDEN
2768 )
2769
27702710 event, context = await self.add_display_name_to_third_party_invite(
27712711 room_version, event_dict, event, context
27722712 )
29652905 return result["max_stream_id"]
29662906 else:
29672907 assert self.storage.persistence
2968 max_stream_token = await self.storage.persistence.persist_events(
2908
2909 # Note that this returns the events that were persisted, which may not be
2910 # the same as were passed in if some were deduplicated due to transaction IDs.
2911 events, max_stream_token = await self.storage.persistence.persist_events(
29692912 event_and_contexts, backfilled=backfilled
29702913 )
29712914
29722915 if self._ephemeral_messages_enabled:
2973 for (event, context) in event_and_contexts:
2916 for event in events:
29742917 # If there's an expiry timestamp on the event, schedule its expiry.
29752918 self._message_handler.maybe_schedule_expiry(event)
29762919
29772920 if not backfilled: # Never notify for backfilled events
2978 for event, _ in event_and_contexts:
2921 for event in events:
29792922 await self._notify_persisted_event(event, max_stream_token)
29802923
29812924 return max_stream_token.stream
30072950 elif event.internal_metadata.is_outlier():
30082951 return
30092952
2953 # the event has been persisted so it should have a stream ordering.
2954 assert event.internal_metadata.stream_ordering
2955
30102956 event_pos = PersistedEventPosition(
30112957 self._instance_name, event.internal_metadata.stream_ordering
30122958 )
1313 # limitations under the License.
1414
1515 import logging
16 from typing import TYPE_CHECKING
16 from typing import TYPE_CHECKING, Optional, Tuple
1717
1818 from twisted.internet import defer
1919
4646 self.state = hs.get_state_handler()
4747 self.clock = hs.get_clock()
4848 self.validator = EventValidator()
49 self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
49 self.snapshot_cache = ResponseCache(
50 hs, "initial_sync_cache"
51 ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
5052 self._event_serializer = hs.get_event_client_serializer()
5153 self.storage = hs.get_storage()
5254 self.state_store = self.storage.state
5355
54 def snapshot_all_rooms(
56 async def snapshot_all_rooms(
5557 self,
5658 user_id: str,
5759 pagin_config: PaginationConfig,
8385 include_archived,
8486 )
8587
86 return self.snapshot_cache.wrap(
88 return await self.snapshot_cache.wrap(
8789 key,
8890 self._snapshot_all_rooms,
8991 user_id,
290292 user_id, room_id, pagin_config, membership, is_peeking
291293 )
292294 elif membership == Membership.LEAVE:
295 # The member_event_id will always be available if membership is set
296 # to leave.
297 assert member_event_id
298
293299 result = await self._room_initial_sync_parted(
294300 user_id, room_id, pagin_config, membership, member_event_id, is_peeking
295301 )
312318 user_id: str,
313319 room_id: str,
314320 pagin_config: PaginationConfig,
315 membership: Membership,
321 membership: str,
316322 member_event_id: str,
317323 is_peeking: bool,
318324 ) -> JsonDict:
364370 user_id: str,
365371 room_id: str,
366372 pagin_config: PaginationConfig,
367 membership: Membership,
373 membership: str,
368374 is_peeking: bool,
369375 ) -> JsonDict:
370376 current_state = await self.state.get_current_state(room_id=room_id)
4949 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
5050 from synapse.storage.state import StateFilter
5151 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
52 from synapse.util import json_decoder
52 from synapse.util import json_decoder, json_encoder
5353 from synapse.util.async_helpers import Linearizer
54 from synapse.util.frozenutils import frozendict_json_encoder
5554 from synapse.util.metrics import measure_func
5655 from synapse.visibility import filter_events_for_client
5756
5857 from ._base import BaseHandler
5958
6059 if TYPE_CHECKING:
60 from synapse.events.third_party_rules import ThirdPartyEventRules
6161 from synapse.server import HomeServer
6262
6363 logger = logging.getLogger(__name__)
392392 self.action_generator = hs.get_action_generator()
393393
394394 self.spam_checker = hs.get_spam_checker()
395 self.third_party_event_rules = hs.get_third_party_event_rules()
395 self.third_party_event_rules = (
396 self.hs.get_third_party_event_rules()
397 ) # type: ThirdPartyEventRules
396398
397399 self._block_events_without_consent_error = (
398400 self.config.block_events_without_consent_error
399401 )
402
403 # we need to construct a ConsentURIBuilder here, as it checks that the necessary
404 # config options are set, but *only* if we have a configuration for which we are
405 # going to need it.
406 if self._block_events_without_consent_error:
407 self._consent_uri_builder = ConsentURIBuilder(self.config)
400408
401409 # Rooms which should be excluded from dummy insertion. (For instance,
402410 # those without local users who can send events into the room).
404412 # map from room id to time-of-last-attempt.
405413 #
406414 self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
407
408 # we need to construct a ConsentURIBuilder here, as it checks that the necessary
409 # config options, but *only* if we have a configuration for which we are
410 # going to need it.
411 if self._block_events_without_consent_error:
412 self._consent_uri_builder = ConsentURIBuilder(self.config)
415 # The number of forward extremities before a dummy event is sent.
416 self._dummy_events_threshold = hs.config.dummy_events_threshold
413417
414418 if (
415 not self.config.worker_app
419 self.config.run_background_tasks
416420 and self.config.cleanup_extremities_with_dummy_events
417421 ):
418422 self.clock.looping_call(
427431
428432 self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
429433
430 self._dummy_events_threshold = hs.config.dummy_events_threshold
431
432434 async def create_event(
433435 self,
434436 requester: Requester,
435437 event_dict: dict,
436 token_id: Optional[str] = None,
437438 txn_id: Optional[str] = None,
438439 prev_event_ids: Optional[List[str]] = None,
440 auth_event_ids: Optional[List[str]] = None,
439441 require_consent: bool = True,
440442 ) -> Tuple[EventBase, EventContext]:
441443 """
449451 Args:
450452 requester
451453 event_dict: An entire event
452 token_id
453454 txn_id
454455 prev_event_ids:
455456 the forward extremities to use as the prev_events for the
456457 new event.
457458
458459 If None, they will be requested from the database.
460
461 auth_event_ids:
462 The event ids to use as the auth_events for the new event.
463 Should normally be left as None, which will cause them to be calculated
464 based on the room state at the prev_events.
465
459466 require_consent: Whether to check if the requester has
460467 consented to the privacy policy.
461468 Raises:
507514 if require_consent and not is_exempt:
508515 await self.assert_accepted_privacy_policy(requester)
509516
510 if token_id is not None:
511 builder.internal_metadata.token_id = token_id
517 if requester.access_token_id is not None:
518 builder.internal_metadata.token_id = requester.access_token_id
512519
513520 if txn_id is not None:
514521 builder.internal_metadata.txn_id = txn_id
515522
516523 event, context = await self.create_new_client_event(
517 builder=builder, requester=requester, prev_event_ids=prev_event_ids,
524 builder=builder,
525 requester=requester,
526 prev_event_ids=prev_event_ids,
527 auth_event_ids=auth_event_ids,
518528 )
519529
520530 # In an ideal world we wouldn't need the second part of this condition. However,
634644 msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
635645 raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
636646
637 async def send_nonmember_event(
638 self,
639 requester: Requester,
640 event: EventBase,
641 context: EventContext,
642 ratelimit: bool = True,
643 ignore_shadow_ban: bool = False,
644 ) -> int:
645 """
646 Persists and notifies local clients and federation of an event.
647
648 Args:
649 requester: The requester sending the event.
650 event: The event to send.
651 context: The context of the event.
652 ratelimit: Whether to rate limit this send.
653 ignore_shadow_ban: True if shadow-banned users should be allowed to
654 send this event.
655
656 Return:
657 The stream_id of the persisted event.
658
659 Raises:
660 ShadowBanError if the requester has been shadow-banned.
661 """
662 if event.type == EventTypes.Member:
663 raise SynapseError(
664 500, "Tried to send member event through non-member codepath"
665 )
666
667 if not ignore_shadow_ban and requester.shadow_banned:
668 # We randomly sleep a bit just to annoy the requester.
669 await self.clock.sleep(random.randint(1, 10))
670 raise ShadowBanError()
671
672 user = UserID.from_string(event.sender)
673
674 assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
675
676 if event.is_state():
677 prev_event = await self.deduplicate_state_event(event, context)
678 if prev_event is not None:
679 logger.info(
680 "Not bothering to persist state event %s duplicated by %s",
681 event.event_id,
682 prev_event.event_id,
683 )
684 return await self.store.get_stream_id_for_event(prev_event.event_id)
685
686 return await self.handle_new_client_event(
687 requester=requester, event=event, context=context, ratelimit=ratelimit
688 )
689
690647 async def deduplicate_state_event(
691648 self, event: EventBase, context: EventContext
692649 ) -> Optional[EventBase]:
727684 """
728685 Creates an event, then sends it.
729686
730 See self.create_event and self.send_nonmember_event.
687 See self.create_event and self.handle_new_client_event.
731688
732689 Args:
733690 requester: The requester sending the event.
737694 ignore_shadow_ban: True if shadow-banned users should be allowed to
738695 send this event.
739696
697 Returns:
698 The event, and its stream ordering (if deduplication happened,
699 the previous, duplicate event).
700
740701 Raises:
741702 ShadowBanError if the requester has been shadow-banned.
742703 """
704
705 if event_dict["type"] == EventTypes.Member:
706 raise SynapseError(
707 500, "Tried to send member event through non-member codepath"
708 )
709
743710 if not ignore_shadow_ban and requester.shadow_banned:
744711 # We randomly sleep a bit just to annoy the requester.
745712 await self.clock.sleep(random.randint(1, 10))
751718 # extremities to pile up, which in turn leads to state resolution
752719 # taking longer.
753720 with (await self.limiter.queue(event_dict["room_id"])):
721 if txn_id and requester.access_token_id:
722 existing_event_id = await self.store.get_event_id_from_transaction_id(
723 event_dict["room_id"],
724 requester.user.to_string(),
725 requester.access_token_id,
726 txn_id,
727 )
728 if existing_event_id:
729 event = await self.store.get_event(existing_event_id)
730 # we know it was persisted, so must have a stream ordering
731 assert event.internal_metadata.stream_ordering
732 return event, event.internal_metadata.stream_ordering
733
754734 event, context = await self.create_event(
755 requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
735 requester, event_dict, txn_id=txn_id
736 )
737
738 assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
739 event.sender,
756740 )
757741
758742 spam_error = self.spam_checker.check_event_for_spam(event)
761745 spam_error = "Spam is not permitted here"
762746 raise SynapseError(403, spam_error, Codes.FORBIDDEN)
763747
764 stream_id = await self.send_nonmember_event(
765 requester,
766 event,
767 context,
748 ev = await self.handle_new_client_event(
749 requester=requester,
750 event=event,
751 context=context,
768752 ratelimit=ratelimit,
769753 ignore_shadow_ban=ignore_shadow_ban,
770754 )
771 return event, stream_id
755
756 # we know it was persisted, so must have a stream ordering
757 assert ev.internal_metadata.stream_ordering
758 return ev, ev.internal_metadata.stream_ordering
772759
773760 @measure_func("create_new_client_event")
774761 async def create_new_client_event(
776763 builder: EventBuilder,
777764 requester: Optional[Requester] = None,
778765 prev_event_ids: Optional[List[str]] = None,
766 auth_event_ids: Optional[List[str]] = None,
779767 ) -> Tuple[EventBase, EventContext]:
780768 """Create a new event for a local client
781769
787775 new event.
788776
789777 If None, they will be requested from the database.
778
779 auth_event_ids:
780 The event ids to use as the auth_events for the new event.
781 Should normally be left as None, which will cause them to be calculated
782 based on the room state at the prev_events.
790783
791784 Returns:
792785 Tuple of created event, context
809802 builder.type == EventTypes.Create or len(prev_event_ids) > 0
810803 ), "Attempting to create an event with no prev_events"
811804
812 event = await builder.build(prev_event_ids=prev_event_ids)
805 event = await builder.build(
806 prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
807 )
813808 context = await self.state.compute_event_context(event)
814809 if requester:
815810 context.app_service = requester.app_service
811
812 third_party_result = await self.third_party_event_rules.check_event_allowed(
813 event, context
814 )
815 if not third_party_result:
816 logger.info(
817 "Event %s forbidden by third-party rules", event,
818 )
819 raise SynapseError(
820 403, "This event is not allowed in this context", Codes.FORBIDDEN
821 )
822 elif isinstance(third_party_result, dict):
823 # the third-party rules want to replace the event. We'll need to build a new
824 # event.
825 event, context = await self._rebuild_event_after_third_party_rules(
826 third_party_result, event
827 )
816828
817829 self.validator.validate_new(event, self.config)
818830
842854 context: EventContext,
843855 ratelimit: bool = True,
844856 extra_users: List[UserID] = [],
845 ) -> int:
846 """Processes a new event. This includes checking auth, persisting it,
857 ignore_shadow_ban: bool = False,
858 ) -> EventBase:
859 """Processes a new event.
860
861 This includes deduplicating, checking auth, persisting,
847862 notifying users, sending to remote servers, etc.
848863
849864 If called from a worker will hit out to the master process for final
856871 ratelimit
857872 extra_users: Any extra users to notify about event
858873
874 ignore_shadow_ban: True if shadow-banned users should be allowed to
875 send this event.
876
859877 Return:
860 The stream_id of the persisted event.
861 """
878 If the event was deduplicated, the previous duplicate event. Otherwise,
879 `event`.
880
881 Raises:
882 ShadowBanError if the requester has been shadow-banned.
883 """
884
885 # we don't apply shadow-banning to membership events here. Invites are blocked
886 # higher up the stack, and we allow shadow-banned users to send join and leave
887 # events as normal.
888 if (
889 event.type != EventTypes.Member
890 and not ignore_shadow_ban
891 and requester.shadow_banned
892 ):
893 # We randomly sleep a bit just to annoy the requester.
894 await self.clock.sleep(random.randint(1, 10))
895 raise ShadowBanError()
896
897 if event.is_state():
898 prev_event = await self.deduplicate_state_event(event, context)
899 if prev_event is not None:
900 logger.info(
901 "Not bothering to persist state event %s duplicated by %s",
902 event.event_id,
903 prev_event.event_id,
904 )
905 return prev_event
862906
863907 if event.is_state() and (event.type, event.state_key) == (
864908 EventTypes.Create,
867911 room_version = event.content.get("room_version", RoomVersions.V1.identifier)
868912 else:
869913 room_version = await self.store.get_room_version_id(event.room_id)
870
871 event_allowed = await self.third_party_event_rules.check_event_allowed(
872 event, context
873 )
874 if not event_allowed:
875 raise SynapseError(
876 403, "This event is not allowed in this context", Codes.FORBIDDEN
877 )
878914
879915 if event.internal_metadata.is_out_of_band_membership():
880916 # the only sort of out-of-band-membership events we expect to see here
890926
891927 # Ensure that we can round trip before trying to persist in db
892928 try:
893 dump = frozendict_json_encoder.encode(event.content)
929 dump = json_encoder.encode(event.content)
894930 json_decoder.decode(dump)
895931 except Exception:
896932 logger.exception("Failed to encode content: %r", event.content)
913949 extra_users=extra_users,
914950 )
915951 stream_id = result["stream_id"]
916 event.internal_metadata.stream_ordering = stream_id
917 return stream_id
918
919 stream_id = await self.persist_and_notify_client_event(
952 event_id = result["event_id"]
953 if event_id != event.event_id:
954 # If we get a different event back then it means that it's
955 # been de-duplicated, so we replace the given event with the
956 # one already persisted.
957 event = await self.store.get_event(event_id)
958 else:
959 # If we newly persisted the event then we need to update its
960 # stream_ordering entry manually (as it was persisted on
961 # another worker).
962 event.internal_metadata.stream_ordering = stream_id
963 return event
964
965 event = await self.persist_and_notify_client_event(
920966 requester, event, context, ratelimit=ratelimit, extra_users=extra_users
921967 )
922968
923 return stream_id
969 return event
924970 except Exception:
925971 # Ensure that we actually remove the entries in the push actions
926972 # staging area, if we calculated them.
9651011 context: EventContext,
9661012 ratelimit: bool = True,
9671013 extra_users: List[UserID] = [],
968 ) -> int:
1014 ) -> EventBase:
9691015 """Called when we have fully built the event, have already
9701016 calculated the push actions for the event, and checked auth.
9711017
9721018 This should only be run on the instance in charge of persisting events.
1019
1020 Returns:
1021 The persisted event. This may be different than the given event if
1022 it was de-duplicated (e.g. because we had already persisted an
1023 event with the same transaction ID.)
9731024 """
9741025 assert self.storage.persistence is not None
9751026 assert self._events_shard_config.should_handle(
10171068
10181069 # Check the alias is currently valid (if it has changed).
10191070 room_alias_str = event.content.get("alias", None)
1020 directory_handler = self.hs.get_handlers().directory_handler
1071 directory_handler = self.hs.get_directory_handler()
10211072 if room_alias_str and room_alias_str != original_alias:
10221073 await self._validate_canonical_alias(
10231074 directory_handler, room_alias_str, event.room_id
10431094 directory_handler, alias_str, event.room_id
10441095 )
10451096
1046 federation_handler = self.hs.get_handlers().federation_handler
1097 federation_handler = self.hs.get_federation_handler()
10471098
10481099 if event.type == EventTypes.Member:
10491100 if event.content["membership"] == Membership.INVITE:
11371188 if prev_state_ids:
11381189 raise AuthError(403, "Changing the room create event is forbidden")
11391190
1140 event_pos, max_stream_token = await self.storage.persistence.persist_event(
1141 event, context=context
1142 )
1191 # Note that this returns the event that was persisted, which may not be
1192 # the same as we passed in if it was deduplicated due to transaction IDs.
1193 (
1194 event,
1195 event_pos,
1196 max_stream_token,
1197 ) = await self.storage.persistence.persist_event(event, context=context)
11431198
11441199 if self._ephemeral_events_enabled:
11451200 # If there's an expiry timestamp on the event, schedule its expiry.
11601215 # matters as sometimes presence code can take a while.
11611216 run_in_background(self._bump_active_time, requester.user)
11621217
1163 return event_pos.stream
1218 return event
11641219
11651220 async def _bump_active_time(self, user: UserID) -> None:
11661221 try:
12311286
12321287 # Since this is a dummy-event it is OK if it is sent by a
12331288 # shadow-banned user.
1234 await self.send_nonmember_event(
1289 await self.handle_new_client_event(
12351290 requester, event, context, ratelimit=False, ignore_shadow_ban=True,
12361291 )
12371292 return True
12591314 room_id,
12601315 )
12611316 del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
1317
1318 async def _rebuild_event_after_third_party_rules(
1319 self, third_party_result: dict, original_event: EventBase
1320 ) -> Tuple[EventBase, EventContext]:
1321 # the third_party_event_rules want to replace the event.
1322 # we do some basic checks, and then return the replacement event and context.
1323
1324 # Construct a new EventBuilder and validate it, which helps with the
1325 # rest of these checks.
1326 try:
1327 builder = self.event_builder_factory.for_room_version(
1328 original_event.room_version, third_party_result
1329 )
1330 self.validator.validate_builder(builder)
1331 except SynapseError as e:
1332 raise Exception(
1333 "Third party rules module created an invalid event: " + e.msg,
1334 )
1335
1336 immutable_fields = [
1337 # changing the room is going to break things: we've already checked that the
1338 # room exists, and are holding a concurrency limiter token for that room.
1339 # Also, we might need to use a different room version.
1340 "room_id",
1341 # changing the type or state key might work, but we'd need to check that the
1342 # calling functions aren't making assumptions about them.
1343 "type",
1344 "state_key",
1345 ]
1346
1347 for k in immutable_fields:
1348 if getattr(builder, k, None) != original_event.get(k):
1349 raise Exception(
1350 "Third party rules module created an invalid event: "
1351 "cannot change field " + k
1352 )
1353
1354 # check that the new sender belongs to this HS
1355 if not self.hs.is_mine_id(builder.sender):
1356 raise Exception(
1357 "Third party rules module created an invalid event: "
1358 "invalid sender " + builder.sender
1359 )
1360
1361 # copy over the original internal metadata
1362 for k, v in original_event.internal_metadata.get_dict().items():
1363 setattr(builder.internal_metadata, k, v)
1364
1365 # the event type hasn't changed, so there's no point in re-calculating the
1366 # auth events.
1367 event = await builder.build(
1368 prev_event_ids=original_event.prev_event_ids(),
1369 auth_event_ids=original_event.auth_event_ids(),
1370 )
1371
1372 # we rebuild the event context, to be on the safe side. If nothing else,
1373 # delta_ids might need an update.
1374 context = await self.state.compute_event_context(event)
1375 return event, context
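The message-handler hunk above gives third-party rules modules a three-way contract: check_event_allowed may allow the event, veto it, or hand back a replacement dict, in which case the immutable fields are re-checked before the event is rebuilt. Below is an illustrative sketch of that contract under simplified assumptions (plain dicts instead of EventBase, and generic exceptions instead of SynapseError); it is not the Synapse implementation itself.

from typing import Union

IMMUTABLE_FIELDS = ("room_id", "type", "state_key")


def apply_third_party_result(original: dict, result: Union[bool, dict]) -> dict:
    """Return the event dict to persist, or raise if the event is vetoed."""
    if not result:
        raise PermissionError("This event is not allowed in this context")
    if result is True:
        return original
    # A dict means "replace the event" -- but some fields must not change.
    for field in IMMUTABLE_FIELDS:
        if result.get(field) != original.get(field):
            raise ValueError(
                "Third party rules module created an invalid event: "
                "cannot change field " + field
            )
    return result


event = {"room_id": "!r:example.org", "type": "m.room.message", "content": {"body": "hi"}}
replacement = dict(event, content={"body": "hi (filtered)"})
print(apply_third_party_result(event, replacement)["content"])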
9595 self.hs = hs
9696 self._callback_url = hs.config.oidc_callback_url # type: str
9797 self._scopes = hs.config.oidc_scopes # type: List[str]
98 self._user_profile_method = hs.config.oidc_user_profile_method # type: str
9899 self._client_auth = ClientAuth(
99100 hs.config.oidc_client_id,
100101 hs.config.oidc_client_secret,
195196 % (m["response_types_supported"],)
196197 )
197198
198 # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
199 # Ensure there's a userinfo endpoint to fetch from if it is required.
199200 if self._uses_userinfo:
200201 if m.get("userinfo_endpoint") is None:
201202 raise ValueError(
202 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
203 'provider has no "userinfo_endpoint", even though it is required'
203204 )
204205 else:
205206 # If we're not using userinfo, we need a valid jwks to validate the ID token
219220 ``access_token`` with the ``userinfo_endpoint``.
220221 """
221222
222 # Maybe that should be user-configurable and not inferred?
223 return "openid" not in self._scopes
223 return (
224 "openid" not in self._scopes
225 or self._user_profile_method == "userinfo_endpoint"
226 )
224227
225228 async def load_metadata(self) -> OpenIDProviderMetadata:
226229 """Load and validate the provider metadata.
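The OIDC hunk above boils down to a single decision: user profile data is fetched from the userinfo endpoint either when the "openid" scope is not requested (so there is no ID token to read claims from) or when the new oidc_user_profile_method option explicitly asks for it. A small sketch of that decision, with a plain function standing in for the handler method and its config:

from typing import Sequence


def uses_userinfo(scopes: Sequence[str], user_profile_method: str) -> bool:
    # Mirrors the logic of _uses_userinfo in the hunk above.
    return "openid" not in scopes or user_profile_method == "userinfo_endpoint"


assert uses_userinfo(["profile"], "auto") is True
assert uses_userinfo(["openid", "profile"], "auto") is False
assert uses_userinfo(["openid"], "userinfo_endpoint") is True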
9191 self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
9292 self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
9393
94 if hs.config.retention_enabled:
94 if hs.config.run_background_tasks and hs.config.retention_enabled:
9595 # Run the purge jobs described in the configuration file.
9696 for job in hs.config.retention_purge_jobs:
9797 logger.info("Setting up purge job with config: %s", job)
382382 "room_key", leave_token
383383 )
384384
385 await self.hs.get_handlers().federation_handler.maybe_backfill(
385 await self.hs.get_federation_handler().maybe_backfill(
386386 room_id, curr_topo, limit=pagin_config.limit,
387387 )
388388
1515
1616 import logging
1717 import re
18 from typing import TYPE_CHECKING
1819
1920 from synapse.api.errors import Codes, PasswordRefusedError
21
22 if TYPE_CHECKING:
23 from synapse.app.homeserver import HomeServer
2024
2125 logger = logging.getLogger(__name__)
2226
2327
2428 class PasswordPolicyHandler:
25 def __init__(self, hs):
29 def __init__(self, hs: "HomeServer"):
2630 self.policy = hs.config.password_policy
2731 self.enabled = hs.config.password_policy_enabled
2832
3236 self.regexp_uppercase = re.compile("[A-Z]")
3337 self.regexp_lowercase = re.compile("[a-z]")
3438
35 def validate_password(self, password):
39 def validate_password(self, password: str) -> None:
3640 """Checks whether a given password complies with the server's policy.
3741
3842 Args:
39 password (str): The password to check against the server's policy.
43 password: The password to check against the server's policy.
4044
4145 Raises:
4246 PasswordRefusedError: The password doesn't comply with the server's policy.
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1514 import logging
1615 import random
16 from typing import TYPE_CHECKING, Optional
1717
1818 from synapse.api.errors import (
1919 AuthError,
2323 StoreError,
2424 SynapseError,
2525 )
26 from synapse.metrics.background_process_metrics import run_as_background_process
27 from synapse.types import UserID, create_requester, get_domain_from_id
26 from synapse.metrics.background_process_metrics import wrap_as_background_process
27 from synapse.types import (
28 JsonDict,
29 Requester,
30 UserID,
31 create_requester,
32 get_domain_from_id,
33 )
2834
2935 from ._base import BaseHandler
36
37 if TYPE_CHECKING:
38 from synapse.app.homeserver import HomeServer
3039
3140 logger = logging.getLogger(__name__)
3241
3443 MAX_AVATAR_URL_LEN = 1000
3544
3645
37 class BaseProfileHandler(BaseHandler):
46 class ProfileHandler(BaseHandler):
3847 """Handles fetching and updating user profile information.
3948
40 BaseProfileHandler can be instantiated directly on workers and will
41 delegate to master when necessary. The master process should use the
42 subclass MasterProfileHandler
49 ProfileHandler can be instantiated directly on workers and will
50 delegate to master when necessary.
4351 """
4452
45 def __init__(self, hs):
53 PROFILE_UPDATE_MS = 60 * 1000
54 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
55
56 def __init__(self, hs: "HomeServer"):
4657 super().__init__(hs)
4758
4859 self.federation = hs.get_federation_client()
5263
5364 self.user_directory_handler = hs.get_user_directory_handler()
5465
55 async def get_profile(self, user_id):
66 if hs.config.run_background_tasks:
67 self.clock.looping_call(
68 self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
69 )
70
71 async def get_profile(self, user_id: str) -> JsonDict:
5672 target_user = UserID.from_string(user_id)
5773
5874 if self.hs.is_mine(target_user):
8399 except HttpResponseException as e:
84100 raise e.to_synapse_error()
85101
86 async def get_profile_from_cache(self, user_id):
102 async def get_profile_from_cache(self, user_id: str) -> JsonDict:
87103 """Get the profile information from our local cache. If the user is
88104 ours then the profile information will always be correct. Otherwise,
89105 it may be out of date/missing.
107123 profile = await self.store.get_from_remote_profile_cache(user_id)
108124 return profile or {}
109125
110 async def get_displayname(self, target_user):
126 async def get_displayname(self, target_user: UserID) -> str:
111127 if self.hs.is_mine(target_user):
112128 try:
113129 displayname = await self.store.get_profile_displayname(
135151 return result["displayname"]
136152
137153 async def set_displayname(
138 self, target_user, requester, new_displayname, by_admin=False
139 ):
154 self,
155 target_user: UserID,
156 requester: Requester,
157 new_displayname: str,
158 by_admin: bool = False,
159 ) -> None:
140160 """Set the displayname of a user
141161
142162 Args:
143 target_user (UserID): the user whose displayname is to be changed.
144 requester (Requester): The user attempting to make this change.
145 new_displayname (str): The displayname to give this user.
146 by_admin (bool): Whether this change was made by an administrator.
163 target_user: the user whose displayname is to be changed.
164 requester: The user attempting to make this change.
165 new_displayname: The displayname to give this user.
166 by_admin: Whether this change was made by an administrator.
147167 """
148168 if not self.hs.is_mine(target_user):
149169 raise SynapseError(400, "User is not hosted on this homeserver")
168188 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
169189 )
170190
191 displayname_to_set = new_displayname # type: Optional[str]
171192 if new_displayname == "":
172 new_displayname = None
193 displayname_to_set = None
173194
174195 # If the admin changes the display name of a user, the requesting user cannot send
175196 # the join event to update the displayname in the rooms.
177198 if by_admin:
178199 requester = create_requester(target_user)
179200
180 await self.store.set_profile_displayname(target_user.localpart, new_displayname)
201 await self.store.set_profile_displayname(
202 target_user.localpart, displayname_to_set
203 )
181204
182205 if self.hs.config.user_directory_search_all_users:
183206 profile = await self.store.get_profileinfo(target_user.localpart)
187210
188211 await self._update_join_states(requester, target_user)
189212
190 async def get_avatar_url(self, target_user):
213 async def get_avatar_url(self, target_user: UserID) -> str:
191214 if self.hs.is_mine(target_user):
192215 try:
193216 avatar_url = await self.store.get_profile_avatar_url(
214237 return result["avatar_url"]
215238
216239 async def set_avatar_url(
217 self, target_user, requester, new_avatar_url, by_admin=False
240 self,
241 target_user: UserID,
242 requester: Requester,
243 new_avatar_url: str,
244 by_admin: bool = False,
218245 ):
219246 """Set a new avatar URL for a user.
220247
221248 Args:
222 target_user (UserID): the user whose avatar URL is to be changed.
223 requester (Requester): The user attempting to make this change.
224 new_avatar_url (str): The avatar URL to give this user.
225 by_admin (bool): Whether this change was made by an administrator.
249 target_user: the user whose avatar URL is to be changed.
250 requester: The user attempting to make this change.
251 new_avatar_url: The avatar URL to give this user.
252 by_admin: Whether this change was made by an administrator.
226253 """
227254 if not self.hs.is_mine(target_user):
228255 raise SynapseError(400, "User is not hosted on this homeserver")
259286
260287 await self._update_join_states(requester, target_user)
261288
262 async def on_profile_query(self, args):
289 async def on_profile_query(self, args: JsonDict) -> JsonDict:
263290 user = UserID.from_string(args["user_id"])
264291 if not self.hs.is_mine(user):
265292 raise SynapseError(400, "User is not hosted on this homeserver")
284311
285312 return response
286313
287 async def _update_join_states(self, requester, target_user):
314 async def _update_join_states(
315 self, requester: Requester, target_user: UserID
316 ) -> None:
288317 if not self.hs.is_mine(target_user):
289318 return
290319
315344 "Failed to update join event for room %s - %s", room_id, str(e)
316345 )
317346
318 async def check_profile_query_allowed(self, target_user, requester=None):
347 async def check_profile_query_allowed(
348 self, target_user: UserID, requester: Optional[UserID] = None
349 ) -> None:
319350 """Checks whether a profile query is allowed. If the
320351 'require_auth_for_profile_requests' config flag is set to True and a
321352 'requester' is provided, the query is only allowed if the two users
322353 share a room.
323354
324355 Args:
325 target_user (UserID): The owner of the queried profile.
326 requester (None|UserID): The user querying for the profile.
356 target_user: The owner of the queried profile.
357 requester: The user querying for the profile.
327358
328359 Raises:
329360 SynapseError(403): The two users share no room, or one user couldn't
362393 raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
363394 raise
364395
365
366 class MasterProfileHandler(BaseProfileHandler):
367 PROFILE_UPDATE_MS = 60 * 1000
368 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
369
370 def __init__(self, hs):
371 super().__init__(hs)
372
373 assert hs.config.worker_app is None
374
375 self.clock.looping_call(
376 self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
377 )
378
379 def _start_update_remote_profile_cache(self):
380 return run_as_background_process(
381 "Update remote profile", self._update_remote_profile_cache
382 )
383
396 @wrap_as_background_process("Update remote profile")
384397 async def _update_remote_profile_cache(self):
385398 """Called periodically to check profiles of remote users we haven't
386399 checked in a while.
1313 # limitations under the License.
1414
1515 import logging
16 from typing import TYPE_CHECKING
1617
1718 from synapse.util.async_helpers import Linearizer
1819
1920 from ._base import BaseHandler
2021
22 if TYPE_CHECKING:
23 from synapse.app.homeserver import HomeServer
24
2125 logger = logging.getLogger(__name__)
2226
2327
2428 class ReadMarkerHandler(BaseHandler):
25 def __init__(self, hs):
29 def __init__(self, hs: "HomeServer"):
2630 super().__init__(hs)
2731 self.server_name = hs.config.server_name
2832 self.store = hs.get_datastore()
2933 self.read_marker_linearizer = Linearizer(name="read_marker")
3034 self.notifier = hs.get_notifier()
3135
32 async def received_client_read_marker(self, room_id, user_id, event_id):
36 async def received_client_read_marker(
37 self, room_id: str, user_id: str, event_id: str
38 ) -> None:
3339 """Updates the read marker for a given user in a given room if the event ID given
3440 is ahead in the stream relative to the current read marker.
3541
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from typing import List, Tuple
1516
17 from synapse.appservice import ApplicationService
1618 from synapse.handlers._base import BaseHandler
17 from synapse.types import ReadReceipt, get_domain_from_id
19 from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
1820 from synapse.util.async_helpers import maybe_awaitable
1921
2022 logger = logging.getLogger(__name__)
139141
140142 return (events, to_key)
141143
144 async def get_new_events_as(
145 self, from_key: int, service: ApplicationService
146 ) -> Tuple[List[JsonDict], int]:
147 """Returns a set of new receipt events that an appservice
148 may be interested in.
149
150 Args:
151 from_key: the stream position at which events should be fetched from
152 service: The appservice which may be interested
153 """
154 from_key = int(from_key)
155 to_key = self.get_current_key()
156
157 if from_key == to_key:
158 return [], to_key
159
160 # We first need to fetch all new receipts
161 rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
162 from_key=from_key, to_key=to_key
163 )
164
165 # Then filter down to rooms that the AS can read
166 events = []
167 for room_id, event in rooms_to_events.items():
168 if not await service.matches_user_in_member_list(room_id, self.store):
169 continue
170
171 events.append(event)
172
173 return (events, to_key)
174
142175 def get_current_key(self, direction="f"):
143176 return self.store.get_max_receipt_stream_id()
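The get_new_events_as helper added above (the typing handler gains the same shape of helper further down) fetches the receipts that arrived in the (from_key, to_key] window and then keeps only the rooms in which the appservice has at least one user. A toy, synchronous version of that filtering step, where the interest check is a hypothetical stand-in for ApplicationService.matches_user_in_member_list and the database:

from typing import Dict, List, Set


def new_receipts_for_appservice(
    receipts_by_room: Dict[str, dict], rooms_with_as_users: Set[str]
) -> List[dict]:
    # Keep only receipts for rooms the appservice can read.
    return [
        event
        for room_id, event in receipts_by_room.items()
        if room_id in rooms_with_as_users
    ]


receipts = {
    "!bridged:example.org": {"type": "m.receipt", "room_id": "!bridged:example.org"},
    "!other:example.org": {"type": "m.receipt", "room_id": "!other:example.org"},
}
print(new_receipts_for_appservice(receipts, {"!bridged:example.org"}))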
4747 self._auth_handler = hs.get_auth_handler()
4848 self.profile_handler = hs.get_profile_handler()
4949 self.user_directory_handler = hs.get_user_directory_handler()
50 self.identity_handler = self.hs.get_handlers().identity_handler
50 self.identity_handler = self.hs.get_identity_handler()
5151 self.ratelimiter = hs.get_registration_ratelimiter()
5252 self.macaroon_gen = hs.get_macaroon_generator()
5353 self._server_notices_mxid = hs.config.server_notices_mxid
119119 # subsequent requests
120120 self._upgrade_response_cache = ResponseCache(
121121 hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
122 )
122 ) # type: ResponseCache[Tuple[str, str]]
123123 self._server_notices_mxid = hs.config.server_notices_mxid
124124
125125 self.third_party_event_rules = hs.get_third_party_event_rules()
184184 ShadowBanError if the requester is shadow-banned.
185185 """
186186 user_id = requester.user.to_string()
187 assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
187188
188189 # start by allocating a new room id
189190 r = await self.store.get_room(old_room_id)
212213 "replacement_room": new_room_id,
213214 },
214215 },
215 token_id=requester.access_token_id,
216216 )
217217 old_room_version = await self.store.get_room_version_id(old_room_id)
218218 await self.auth.check_from_context(
228228 )
229229
230230 # now send the tombstone
231 await self.event_creation_handler.send_nonmember_event(
232 requester, tombstone_event, tombstone_context
231 await self.event_creation_handler.handle_new_client_event(
232 requester=requester, event=tombstone_event, context=tombstone_context,
233233 )
234234
235235 old_room_state = await tombstone_context.get_current_state_ids()
680680 creator_id=user_id, is_public=is_public, room_version=room_version,
681681 )
682682
683 directory_handler = self.hs.get_handlers().directory_handler
683 # Check whether this visibility value is blocked by a third party module
684 allowed_by_third_party_rules = await (
685 self.third_party_event_rules.check_visibility_can_be_modified(
686 room_id, visibility
687 )
688 )
689 if not allowed_by_third_party_rules:
690 raise SynapseError(403, "Room visibility value not allowed.")
691
692 directory_handler = self.hs.get_directory_handler()
684693 if room_alias:
685694 await directory_handler.create_association(
686695 requester=requester,
961970 try:
962971 random_string = stringutils.random_string(18)
963972 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
964 if isinstance(gen_room_id, bytes):
965 gen_room_id = gen_room_id.decode("utf-8")
966973 await self.store.store_room(
967974 room_id=gen_room_id,
968975 room_creator_user_id=creator_id,
1616 import logging
1717 import random
1818 from http import HTTPStatus
19 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
20
21 from unpaddedbase64 import encode_base64
19 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
2220
2321 from synapse import types
24 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
22 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
2523 from synapse.api.errors import (
2624 AuthError,
2725 Codes,
3028 SynapseError,
3129 )
3230 from synapse.api.ratelimiting import Ratelimiter
33 from synapse.api.room_versions import EventFormatVersions
34 from synapse.crypto.event_signing import compute_event_reference_hash
3531 from synapse.events import EventBase
36 from synapse.events.builder import create_local_event_from_event_dict
3732 from synapse.events.snapshot import EventContext
38 from synapse.events.validator import EventValidator
3933 from synapse.storage.roommember import RoomsForUser
4034 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
4135 from synapse.util.async_helpers import Linearizer
6357 self.state_handler = hs.get_state_handler()
6458 self.config = hs.config
6559
66 self.federation_handler = hs.get_handlers().federation_handler
67 self.directory_handler = hs.get_handlers().directory_handler
68 self.identity_handler = hs.get_handlers().identity_handler
60 self.federation_handler = hs.get_federation_handler()
61 self.directory_handler = hs.get_directory_handler()
62 self.identity_handler = hs.get_identity_handler()
6963 self.registration_handler = hs.get_registration_handler()
7064 self.profile_handler = hs.get_profile_handler()
7165 self.event_creation_handler = hs.get_event_creation_handler()
170164 if requester.is_guest:
171165 content["kind"] = "guest"
172166
167 # Check if we already have an event with a matching transaction ID. (We
168 # do this check just before we persist an event as well, but may as well
169 # do it up front for efficiency.)
170 if txn_id and requester.access_token_id:
171 existing_event_id = await self.store.get_event_id_from_transaction_id(
172 room_id, requester.user.to_string(), requester.access_token_id, txn_id,
173 )
174 if existing_event_id:
175 event_pos = await self.store.get_position_for_event(existing_event_id)
176 return existing_event_id, event_pos.stream
177
173178 event, context = await self.event_creation_handler.create_event(
174179 requester,
175180 {
181186 # For backwards compatibility:
182187 "membership": membership,
183188 },
184 token_id=requester.access_token_id,
185189 txn_id=txn_id,
186190 prev_event_ids=prev_event_ids,
187191 require_consent=require_consent,
188192 )
189
190 # Check if this event matches the previous membership event for the user.
191 duplicate = await self.event_creation_handler.deduplicate_state_event(
192 event, context
193 )
194 if duplicate is not None:
195 # Discard the new event since this membership change is a no-op.
196 _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
197 return duplicate.event_id, stream_id
198193
199194 prev_state_ids = await context.get_prev_state_ids()
200195
220215 retry_after_ms=int(1000 * (time_allowed - time_now_s))
221216 )
222217
223 stream_id = await self.event_creation_handler.handle_new_client_event(
218 result_event = await self.event_creation_handler.handle_new_client_event(
224219 requester, event, context, extra_users=[target], ratelimit=ratelimit,
225220 )
226221
230225 if prev_member_event.membership == Membership.JOIN:
231226 await self._user_left_room(target, room_id)
232227
233 return event.event_id, stream_id
228 # we know it was persisted, so should have a stream ordering
229 assert result_event.internal_metadata.stream_ordering
230 return result_event.event_id, result_event.internal_metadata.stream_ordering
234231
235232 async def copy_room_tags_and_direct_to_room(
236233 self, old_room_id, new_room_id, user_id
246243 user_account_data, _ = await self.store.get_account_data_for_user(user_id)
247244
248245 # Copy direct message state if applicable
249 direct_rooms = user_account_data.get("m.direct", {})
246 direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
250247
251248 # Check which key this room is under
252249 if isinstance(direct_rooms, dict):
257254
258255 # Save back to user's m.direct account data
259256 await self.store.add_account_data_for_user(
260 user_id, "m.direct", direct_rooms
257 user_id, AccountDataTypes.DIRECT, direct_rooms
261258 )
262259 break
263260
440437 same_membership = old_membership == effective_membership_state
441438 same_sender = requester.user.to_string() == old_state.sender
442439 if same_sender and same_membership and same_content:
443 _, stream_id = await self.store.get_event_ordering(
444 old_state.event_id
445 )
440 # duplicate event.
441 # we know it was persisted, so must have a stream ordering.
442 assert old_state.internal_metadata.stream_ordering
446443 return (
447444 old_state.event_id,
448 stream_id,
445 old_state.internal_metadata.stream_ordering,
449446 )
450447
451448 if old_membership in ["ban", "leave"] and action == "kick":
641638
642639 async def send_membership_event(
643640 self,
644 requester: Requester,
641 requester: Optional[Requester],
645642 event: EventBase,
646643 context: EventContext,
647644 ratelimit: bool = True,
671668 else:
672669 requester = types.create_requester(target_user)
673670
674 prev_event = await self.event_creation_handler.deduplicate_state_event(
675 event, context
676 )
677 if prev_event is not None:
678 return
679
680671 prev_state_ids = await context.get_prev_state_ids()
681672 if event.membership == Membership.JOIN:
682673 if requester.is_guest:
691682 if is_blocked:
692683 raise SynapseError(403, "This room has been blocked on this server")
693684
694 await self.event_creation_handler.handle_new_client_event(
685 event = await self.event_creation_handler.handle_new_client_event(
695686 requester, event, context, extra_users=[target_user], ratelimit=ratelimit
696687 )
697688
11341125
11351126 room_id = invite_event.room_id
11361127 target_user = invite_event.state_key
1137 room_version = await self.store.get_room_version(room_id)
11381128
11391129 content["membership"] = Membership.LEAVE
11401130
1141 # the auth events for the new event are the same as that of the invite, plus
1142 # the invite itself.
1143 #
1144 # the prev_events are just the invite.
1145 invite_hash = invite_event.event_id # type: Union[str, Tuple]
1146 if room_version.event_format == EventFormatVersions.V1:
1147 alg, h = compute_event_reference_hash(invite_event)
1148 invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
1149
1150 auth_events = tuple(invite_event.auth_events) + (invite_hash,)
1151 prev_events = (invite_hash,)
1152
1153 # we cap depth of generated events, to ensure that they are not
1154 # rejected by other servers (and so that they can be persisted in
1155 # the db)
1156 depth = min(invite_event.depth + 1, MAX_DEPTH)
1157
11581131 event_dict = {
1159 "depth": depth,
1160 "auth_events": auth_events,
1161 "prev_events": prev_events,
11621132 "type": EventTypes.Member,
11631133 "room_id": room_id,
11641134 "sender": target_user,
11661136 "state_key": target_user,
11671137 }
11681138
1169 event = create_local_event_from_event_dict(
1170 clock=self.clock,
1171 hostname=self.hs.hostname,
1172 signing_key=self.hs.signing_key,
1173 room_version=room_version,
1174 event_dict=event_dict,
1139 # the auth events for the new event are the same as that of the invite, plus
1140 # the invite itself.
1141 #
1142 # the prev_events are just the invite.
1143 prev_event_ids = [invite_event.event_id]
1144 auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
1145
1146 event, context = await self.event_creation_handler.create_event(
1147 requester,
1148 event_dict,
1149 txn_id=txn_id,
1150 prev_event_ids=prev_event_ids,
1151 auth_event_ids=auth_event_ids,
11751152 )
11761153 event.internal_metadata.outlier = True
11771154 event.internal_metadata.out_of_band_membership = True
1178 if txn_id is not None:
1179 event.internal_metadata.txn_id = txn_id
1180 if requester.access_token_id is not None:
1181 event.internal_metadata.token_id = requester.access_token_id
1182
1183 EventValidator().validate_new(event, self.config)
1184
1185 context = await self.state_handler.compute_event_context(event)
1186 context.app_service = requester.app_service
1187 stream_id = await self.event_creation_handler.handle_new_client_event(
1155
1156 result_event = await self.event_creation_handler.handle_new_client_event(
11881157 requester, event, context, extra_users=[UserID.from_string(target_user)],
11891158 )
1190 return event.event_id, stream_id
1159 # we know it was persisted, so must have a stream ordering
1160 assert result_event.internal_metadata.stream_ordering
1161
1162 return result_event.event_id, result_event.internal_metadata.stream_ordering
11911163
11921164 async def _user_left_room(self, target: UserID, room_id: str) -> None:
11931165 """Implements RoomMemberHandler._user_left_room
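Both the message-handler and room-member hunks above add the same idempotency check: before creating a new event, look up whether this (room, user, access token, txn_id) combination has already produced one, and if so return the existing event instead of persisting a duplicate. A minimal sketch of that pattern, where the in-memory "store" is a hypothetical stand-in for the real transaction-ID table:

from typing import Dict, Optional, Tuple

TxnKey = Tuple[str, str, int, str]  # (room_id, user_id, access_token_id, txn_id)


class InMemoryTxnStore:
    def __init__(self) -> None:
        self._by_txn: Dict[TxnKey, str] = {}

    def get_event_id_from_transaction_id(self, key: TxnKey) -> Optional[str]:
        return self._by_txn.get(key)

    def record(self, key: TxnKey, event_id: str) -> None:
        self._by_txn[key] = event_id


def send_with_txn_id(store: InMemoryTxnStore, key: TxnKey, new_event_id: str) -> str:
    existing = store.get_event_id_from_transaction_id(key)
    if existing is not None:
        # Retried request: hand back the event we already persisted.
        return existing
    store.record(key, new_event_id)
    return new_event_id


store = InMemoryTxnStore()
key = ("!room:example.org", "@alice:example.org", 42, "txn-1")
first = send_with_txn_id(store, key, "$event-A")
retry = send_with_txn_id(store, key, "$event-B")
assert first == retry == "$event-A"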
4848 # Guard to ensure we only process deltas one at a time
4949 self._is_processing = False
5050
51 if hs.config.stats_enabled:
51 if self.stats_enabled and hs.config.run_background_tasks:
5252 self.notifier.add_replication_callback(self.notify_new_event)
5353
5454 # We kick this off so that we don't have to wait for a change before
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15
1615 import itertools
1716 import logging
1817 from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
2019 import attr
2120 from prometheus_client import Counter
2221
23 from synapse.api.constants import EventTypes, Membership
22 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
2423 from synapse.api.filtering import FilterCollection
2524 from synapse.events import EventBase
2625 from synapse.logging.context import current_context
8685 class TimelineBatch:
8786 prev_batch = attr.ib(type=StreamToken)
8887 events = attr.ib(type=List[EventBase])
89 limited = attr.ib(bool)
88 limited = attr.ib(type=bool)
9089
9190 def __bool__(self) -> bool:
9291 """Make the result appear empty if there are no updates. This is used
200199 device_lists: List of user_ids whose devices have changed
201200 device_one_time_keys_count: Dict of algorithm to count for one time keys
202201 for this device
202 device_unused_fallback_key_types: List of key types that have an unused fallback
203 key
203204 groups: Group updates, if any
204205 """
205206
212213 to_device = attr.ib(type=List[JsonDict])
213214 device_lists = attr.ib(type=DeviceLists)
214215 device_one_time_keys_count = attr.ib(type=JsonDict)
216 device_unused_fallback_key_types = attr.ib(type=List[str])
215217 groups = attr.ib(type=Optional[GroupsSyncResult])
216218
217219 def __bool__(self) -> bool:
239241 self.presence_handler = hs.get_presence_handler()
240242 self.event_sources = hs.get_event_sources()
241243 self.clock = hs.get_clock()
242 self.response_cache = ResponseCache(hs, "sync")
244 self.response_cache = ResponseCache(
245 hs, "sync"
246 ) # type: ResponseCache[Tuple[Any, ...]]
243247 self.state = hs.get_state_handler()
244248 self.auth = hs.get_auth()
245249 self.storage = hs.get_storage()
456460 recents = []
457461
458462 if not limited or block_all_timeline:
463 prev_batch_token = now_token
464 if recents:
465 room_key = recents[0].internal_metadata.before
466 prev_batch_token = now_token.copy_and_replace("room_key", room_key)
467
459468 return TimelineBatch(
460 events=recents, prev_batch=now_token, limited=False
469 events=recents, prev_batch=prev_batch_token, limited=False
461470 )
462471
463472 filtering_factor = 2
10131022 logger.debug("Fetching OTK data")
10141023 device_id = sync_config.device_id
10151024 one_time_key_counts = {} # type: JsonDict
1025 unused_fallback_key_types = [] # type: List[str]
10161026 if device_id:
10171027 one_time_key_counts = await self.store.count_e2e_one_time_keys(
1028 user_id, device_id
1029 )
1030 unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
10181031 user_id, device_id
10191032 )
10201033
10401053 device_lists=device_lists,
10411054 groups=sync_result_builder.groups,
10421055 device_one_time_keys_count=one_time_key_counts,
1056 device_unused_fallback_key_types=unused_fallback_key_types,
10431057 next_batch=sync_result_builder.now_token,
10441058 )
10451059
13771391 return set(), set(), set(), set()
13781392
13791393 ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
1380 "m.ignored_user_list", user_id=user_id
1381 )
1382
1394 AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
1395 )
1396
1397 # If there is ignored users account data and it matches the proper type,
1398 # then use it.
1399 ignored_users = frozenset() # type: FrozenSet[str]
13831400 if ignored_account_data:
1384 ignored_users = ignored_account_data.get("ignored_users", {}).keys()
1385 else:
1386 ignored_users = frozenset()
1401 ignored_users_data = ignored_account_data.get("ignored_users", {})
1402 if isinstance(ignored_users_data, dict):
1403 ignored_users = frozenset(ignored_users_data.keys())
13871404
13881405 if since_token:
13891406 room_changes = await self._get_rooms_changed(
14771494 return False
14781495
14791496 async def _get_rooms_changed(
1480 self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
1497 self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
14811498 ) -> _RoomChanges:
14821499 """Gets the changes that have happened since the last sync.
14831500 """
16891706 return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
16901707
16911708 async def _get_all_rooms(
1692 self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
1709 self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
16931710 ) -> _RoomChanges:
16941711 """Returns entries for all rooms for the user.
16951712
17631780 async def _generate_room_entry(
17641781 self,
17651782 sync_result_builder: "SyncResultBuilder",
1766 ignored_users: Set[str],
1783 ignored_users: FrozenSet[str],
17671784 room_builder: "RoomSyncResultBuilder",
17681785 ephemeral: List[JsonDict],
17691786 tags: Optional[Dict[str, Dict[str, Any]]],
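The sync hunk above switches to parsing m.ignored_user_list defensively: the account data is client-supplied, so the "ignored_users" value is only trusted if it really is a dict, and the result is exposed as an immutable frozenset. A short sketch of that parsing, assuming plain dict account data:

from typing import Any, FrozenSet, Optional


def parse_ignored_users(account_data: Optional[dict]) -> FrozenSet[str]:
    ignored: FrozenSet[str] = frozenset()
    if account_data:
        ignored_users_data: Any = account_data.get("ignored_users", {})
        # Only use the value if it matches the expected type.
        if isinstance(ignored_users_data, dict):
            ignored = frozenset(ignored_users_data.keys())
    return ignored


assert parse_ignored_users({"ignored_users": {"@spam:example.org": {}}}) == frozenset(
    {"@spam:example.org"}
)
assert parse_ignored_users({"ignored_users": ["not-a-dict"]}) == frozenset()
assert parse_ignored_users(None) == frozenset()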
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1514 import logging
1615 import random
1716 from collections import namedtuple
1817 from typing import TYPE_CHECKING, List, Set, Tuple
1918
2019 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
20 from synapse.appservice import ApplicationService
2121 from synapse.metrics.background_process_metrics import run_as_background_process
2222 from synapse.replication.tcp.streams import TypingStream
23 from synapse.types import UserID, get_domain_from_id
23 from synapse.types import JsonDict, UserID, get_domain_from_id
2424 from synapse.util.caches.stream_change_cache import StreamChangeCache
2525 from synapse.util.metrics import Measure
2626 from synapse.util.wheel_timer import WheelTimer
429429 "content": {"user_ids": list(typing)},
430430 }
431431
432 async def get_new_events_as(
433 self, from_key: int, service: ApplicationService
434 ) -> Tuple[List[JsonDict], int]:
435 """Returns a set of new typing events that an appservice
436 may be interested in.
437
438 Args:
439 from_key: the stream position at which events should be fetched from
440 service: The appservice which may be interested
441 """
442 with Measure(self.clock, "typing.get_new_events_as"):
443 from_key = int(from_key)
444 handler = self.get_typing_handler()
445
446 events = []
447 for room_id in handler._room_serials.keys():
448 if handler._room_serials[room_id] <= from_key:
449 continue
450 if not await service.matches_user_in_member_list(
451 room_id, handler.store
452 ):
453 continue
454
455 events.append(self._make_event_for(room_id))
456
457 return (events, handler._latest_room_serial)
458
432459 async def get_new_events(self, from_key, room_ids, **kwargs):
433460 with Measure(self.clock, "typing.get_new_events"):
434461 from_key = int(from_key)
142142
143143 threepid_creds = authdict["threepid_creds"]
144144
145 identity_handler = self.hs.get_handlers().identity_handler
145 identity_handler = self.hs.get_identity_handler()
146146
147147 logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
148148
3434 from twisted.web.static import File, NoRangeStaticProducer
3535 from twisted.web.util import redirectTo
3636
37 import synapse.events
38 import synapse.metrics
3937 from synapse.api.errors import (
4038 CodeMessageException,
4139 Codes,
256254 if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
257255 callback_return = await raw_callback_return
258256 else:
259 callback_return = raw_callback_return
257 callback_return = raw_callback_return # type: ignore
260258
261259 return callback_return
262260
405403 if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
406404 callback_return = await raw_callback_return
407405 else:
408 callback_return = raw_callback_return
406 callback_return = raw_callback_return # type: ignore
409407
410408 return callback_return
411409
619617 if pretty_print:
620618 encoder = iterencode_pretty_printed_json
621619 else:
622 if canonical_json or synapse.events.USE_FROZEN_DICTS:
620 if canonical_json:
623621 encoder = iterencode_canonical_json
624622 else:
625623 encoder = _encode_json_bytes
650648 Returns:
651649 twisted.web.server.NOT_DONE_YET if the request is still active.
652650 """
651 if request._disconnected:
652 logger.warning(
653 "Not sending response to request %s, already disconnected.", request
654 )
655 return
653656
654657 request.setResponseCode(code)
655658 request.setHeader(b"Content-Type", b"application/json")
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import sys
16 import traceback
17 from collections import deque
18 from ipaddress import IPv4Address, IPv6Address, ip_address
19 from math import floor
20 from typing import Callable, Optional
21
22 import attr
23 from zope.interface import implementer
24
25 from twisted.application.internet import ClientService
26 from twisted.internet.defer import Deferred
27 from twisted.internet.endpoints import (
28 HostnameEndpoint,
29 TCP4ClientEndpoint,
30 TCP6ClientEndpoint,
31 )
32 from twisted.internet.interfaces import IPushProducer, ITransport
33 from twisted.internet.protocol import Factory, Protocol
34 from twisted.logger import ILogObserver, Logger, LogLevel
35
36
37 @attr.s
38 @implementer(IPushProducer)
39 class LogProducer:
40 """
41 An IPushProducer that writes logs from its buffer to its transport when it
42 is resumed.
43
44 Args:
45 buffer: Log buffer to read logs from.
46 transport: Transport to write to.
47 format_event: A callable to format the log entry to a string.
48 """
49
50 transport = attr.ib(type=ITransport)
51 format_event = attr.ib(type=Callable[[dict], str])
52 _buffer = attr.ib(type=deque)
53 _paused = attr.ib(default=False, type=bool, init=False)
54
55 def pauseProducing(self):
56 self._paused = True
57
58 def stopProducing(self):
59 self._paused = True
60 self._buffer = deque()
61
62 def resumeProducing(self):
63 self._paused = False
64
65 while self._paused is False and (self._buffer and self.transport.connected):
66 try:
67 # Request the next event and format it.
68 event = self._buffer.popleft()
69 msg = self.format_event(event)
70
71 # Send it as a new line over the transport.
72 self.transport.write(msg.encode("utf8"))
73 except Exception:
74 # Something has gone wrong writing to the transport -- log it
75 # and break out of the while.
76 traceback.print_exc(file=sys.__stderr__)
77 break
78
79
80 @attr.s
81 @implementer(ILogObserver)
82 class TCPLogObserver:
83 """
84 An IObserver that writes JSON logs to a TCP target.
85
86 Args:
87 hs (HomeServer): The homeserver that is being logged for.
88 host: The host of the logging target.
89 port: The logging target's port.
90 format_event: A callable to format the log entry to a string.
91 maximum_buffer: The maximum buffer size.
92 """
93
94 hs = attr.ib()
95 host = attr.ib(type=str)
96 port = attr.ib(type=int)
97 format_event = attr.ib(type=Callable[[dict], str])
98 maximum_buffer = attr.ib(type=int)
99 _buffer = attr.ib(default=attr.Factory(deque), type=deque)
100 _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
101 _logger = attr.ib(default=attr.Factory(Logger))
102 _producer = attr.ib(default=None, type=Optional[LogProducer])
103
104 def start(self) -> None:
105
106 # Connect without DNS lookups if it's a direct IP.
107 try:
108 ip = ip_address(self.host)
109 if isinstance(ip, IPv4Address):
110 endpoint = TCP4ClientEndpoint(
111 self.hs.get_reactor(), self.host, self.port
112 )
113 elif isinstance(ip, IPv6Address):
114 endpoint = TCP6ClientEndpoint(
115 self.hs.get_reactor(), self.host, self.port
116 )
117 else:
118 raise ValueError("Unknown IP address provided: %s" % (self.host,))
119 except ValueError:
120 endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
121
122 factory = Factory.forProtocol(Protocol)
123 self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
124 self._service.startService()
125 self._connect()
126
127 def stop(self):
128 self._service.stopService()
129
130 def _connect(self) -> None:
131 """
132 Triggers an attempt to connect then write to the remote if not already writing.
133 """
134 if self._connection_waiter:
135 return
136
137 self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
138
139 @self._connection_waiter.addErrback
140 def fail(r):
141 r.printTraceback(file=sys.__stderr__)
142 self._connection_waiter = None
143 self._connect()
144
145 @self._connection_waiter.addCallback
146 def writer(r):
147 # We have a connection. If we already have a producer, and its
148 # transport is the same, just trigger a resumeProducing.
149 if self._producer and r.transport is self._producer.transport:
150 self._producer.resumeProducing()
151 self._connection_waiter = None
152 return
153
154 # If the producer is still producing, stop it.
155 if self._producer:
156 self._producer.stopProducing()
157
158 # Make a new producer and start it.
159 self._producer = LogProducer(
160 buffer=self._buffer,
161 transport=r.transport,
162 format_event=self.format_event,
163 )
164 r.transport.registerProducer(self._producer, True)
165 self._producer.resumeProducing()
166 self._connection_waiter = None
167
168 def _handle_pressure(self) -> None:
169 """
170 Handle backpressure by shedding events.
171
172         The buffer will be shed, in this order, until it is below the maximum:
173 - Shed DEBUG events
174 - Shed INFO events
175 - Shed the middle 50% of the events.
176 """
177 if len(self._buffer) <= self.maximum_buffer:
178 return
179
180 # Strip out DEBUGs
181 self._buffer = deque(
182 filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
183 )
184
185 if len(self._buffer) <= self.maximum_buffer:
186 return
187
188 # Strip out INFOs
189 self._buffer = deque(
190 filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
191 )
192
193 if len(self._buffer) <= self.maximum_buffer:
194 return
195
196 # Cut the middle entries out
197 buffer_split = floor(self.maximum_buffer / 2)
198
199 old_buffer = self._buffer
200 self._buffer = deque()
201
202 for i in range(buffer_split):
203 self._buffer.append(old_buffer.popleft())
204
205 end_buffer = []
206 for i in range(buffer_split):
207 end_buffer.append(old_buffer.pop())
208
209 self._buffer.extend(reversed(end_buffer))
210
211 def __call__(self, event: dict) -> None:
212 self._buffer.append(event)
213
214 # Handle backpressure, if it exists.
215 try:
216 self._handle_pressure()
217 except Exception:
218             # If handling backpressure fails, clear the buffer and log the
219 # exception.
220 self._buffer.clear()
221 self._logger.failure("Failed clearing backpressure")
222
223 # Try and write immediately.
224 self._connect()
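
A minimal standalone sketch of the shedding policy implemented by `_handle_pressure` above, restated outside the class so the three stages are easy to follow. `shed_events` and the toy buffer are illustrative names, not part of this diff; events are assumed to be dicts carrying a twisted LogLevel under "log_level", as in the observer above.

    from collections import deque
    from math import floor

    from twisted.logger import LogLevel


    def shed_events(buffer: deque, maximum_buffer: int) -> deque:
        """Apply the same policy as TCPLogObserver._handle_pressure:
        drop DEBUG events, then INFO events, then the middle of the buffer."""
        if len(buffer) <= maximum_buffer:
            return buffer

        # Stage 1: drop DEBUG events.
        buffer = deque(e for e in buffer if e["log_level"] != LogLevel.debug)
        if len(buffer) <= maximum_buffer:
            return buffer

        # Stage 2: drop INFO events.
        buffer = deque(e for e in buffer if e["log_level"] != LogLevel.info)
        if len(buffer) <= maximum_buffer:
            return buffer

        # Stage 3: keep the oldest and newest maximum_buffer // 2 events,
        # discarding everything in the middle.
        keep = floor(maximum_buffer / 2)
        events = list(buffer)
        return deque(events[:keep] + events[-keep:])
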
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1514 import logging
1615 import os.path
1716 import sys
8887 context = current_context()
8988
9089 # Copy the context information to the log event.
91 if context is not None:
92 context.copy_to_twisted_log_entry(event)
93 else:
94 # If there's no logging context, not even the root one, we might be
95 # starting up or it might be from non-Synapse code. Log it as if it
96 # came from the root logger.
97 event["request"] = None
98 event["scope"] = None
90 context.copy_to_twisted_log_entry(event)
9991
10092 self.observer(event)
10193
1717 """
1818
1919 import json
20 import sys
21 import traceback
22 from collections import deque
23 from ipaddress import IPv4Address, IPv6Address, ip_address
24 from math import floor
25 from typing import IO, Optional
20 from typing import IO
2621
27 import attr
28 from zope.interface import implementer
22 from twisted.logger import FileLogObserver
2923
30 from twisted.application.internet import ClientService
31 from twisted.internet.defer import Deferred
32 from twisted.internet.endpoints import (
33 HostnameEndpoint,
34 TCP4ClientEndpoint,
35 TCP6ClientEndpoint,
36 )
37 from twisted.internet.interfaces import IPushProducer, ITransport
38 from twisted.internet.protocol import Factory, Protocol
39 from twisted.logger import FileLogObserver, ILogObserver, Logger
24 from synapse.logging._remote import TCPLogObserver
4025
4126 _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
4227
149134 return FileLogObserver(outFile, formatEvent)
150135
151136
152 @attr.s
153 @implementer(IPushProducer)
154 class LogProducer:
137 def TerseJSONToTCPLogObserver(
138 hs, host: str, port: int, metadata: dict, maximum_buffer: int
139 ) -> FileLogObserver:
155140 """
156 An IPushProducer that writes logs from its buffer to its transport when it
157 is resumed.
158
159 Args:
160 buffer: Log buffer to read logs from.
161 transport: Transport to write to.
162 """
163
164 transport = attr.ib(type=ITransport)
165 _buffer = attr.ib(type=deque)
166 _paused = attr.ib(default=False, type=bool, init=False)
167
168 def pauseProducing(self):
169 self._paused = True
170
171 def stopProducing(self):
172 self._paused = True
173 self._buffer = deque()
174
175 def resumeProducing(self):
176 self._paused = False
177
178 while self._paused is False and (self._buffer and self.transport.connected):
179 try:
180 event = self._buffer.popleft()
181 self.transport.write(_encoder.encode(event).encode("utf8"))
182 self.transport.write(b"\n")
183 except Exception:
184 # Something has gone wrong writing to the transport -- log it
185 # and break out of the while.
186 traceback.print_exc(file=sys.__stderr__)
187 break
188
189
190 @attr.s
191 @implementer(ILogObserver)
192 class TerseJSONToTCPLogObserver:
193 """
194 An IObserver that writes JSON logs to a TCP target.
141 A log observer that formats events to a flattened JSON representation.
195142
196143 Args:
197144 hs (HomeServer): The homeserver that is being logged for.
198145 host: The host of the logging target.
199146 port: The logging target's port.
200 metadata: Metadata to be added to each log entry.
147 metadata: Metadata to be added to each log object.
148 maximum_buffer: The maximum buffer size.
201149 """
202150
203 hs = attr.ib()
204 host = attr.ib(type=str)
205 port = attr.ib(type=int)
206 metadata = attr.ib(type=dict)
207 maximum_buffer = attr.ib(type=int)
208 _buffer = attr.ib(default=attr.Factory(deque), type=deque)
209 _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
210 _logger = attr.ib(default=attr.Factory(Logger))
211 _producer = attr.ib(default=None, type=Optional[LogProducer])
151 def formatEvent(_event: dict) -> str:
152 flattened = flatten_event(_event, metadata, include_time=True)
153 return _encoder.encode(flattened) + "\n"
212154
213 def start(self) -> None:
214
215 # Connect without DNS lookups if it's a direct IP.
216 try:
217 ip = ip_address(self.host)
218 if isinstance(ip, IPv4Address):
219 endpoint = TCP4ClientEndpoint(
220 self.hs.get_reactor(), self.host, self.port
221 )
222 elif isinstance(ip, IPv6Address):
223 endpoint = TCP6ClientEndpoint(
224 self.hs.get_reactor(), self.host, self.port
225 )
226 except ValueError:
227 endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
228
229 factory = Factory.forProtocol(Protocol)
230 self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
231 self._service.startService()
232 self._connect()
233
234 def stop(self):
235 self._service.stopService()
236
237 def _connect(self) -> None:
238 """
239 Triggers an attempt to connect then write to the remote if not already writing.
240 """
241 if self._connection_waiter:
242 return
243
244 self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
245
246 @self._connection_waiter.addErrback
247 def fail(r):
248 r.printTraceback(file=sys.__stderr__)
249 self._connection_waiter = None
250 self._connect()
251
252 @self._connection_waiter.addCallback
253 def writer(r):
254 # We have a connection. If we already have a producer, and its
255 # transport is the same, just trigger a resumeProducing.
256 if self._producer and r.transport is self._producer.transport:
257 self._producer.resumeProducing()
258 self._connection_waiter = None
259 return
260
261 # If the producer is still producing, stop it.
262 if self._producer:
263 self._producer.stopProducing()
264
265 # Make a new producer and start it.
266 self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
267 r.transport.registerProducer(self._producer, True)
268 self._producer.resumeProducing()
269 self._connection_waiter = None
270
271 def _handle_pressure(self) -> None:
272 """
273 Handle backpressure by shedding events.
274
275 The buffer will, in this order, until the buffer is below the maximum:
276 - Shed DEBUG events
277 - Shed INFO events
278 - Shed the middle 50% of the events.
279 """
280 if len(self._buffer) <= self.maximum_buffer:
281 return
282
283 # Strip out DEBUGs
284 self._buffer = deque(
285 filter(lambda event: event["level"] != "DEBUG", self._buffer)
286 )
287
288 if len(self._buffer) <= self.maximum_buffer:
289 return
290
291 # Strip out INFOs
292 self._buffer = deque(
293 filter(lambda event: event["level"] != "INFO", self._buffer)
294 )
295
296 if len(self._buffer) <= self.maximum_buffer:
297 return
298
299 # Cut the middle entries out
300 buffer_split = floor(self.maximum_buffer / 2)
301
302 old_buffer = self._buffer
303 self._buffer = deque()
304
305 for i in range(buffer_split):
306 self._buffer.append(old_buffer.popleft())
307
308 end_buffer = []
309 for i in range(buffer_split):
310 end_buffer.append(old_buffer.pop())
311
312 self._buffer.extend(reversed(end_buffer))
313
314 def __call__(self, event: dict) -> None:
315 flattened = flatten_event(event, self.metadata, include_time=True)
316 self._buffer.append(flattened)
317
318 # Handle backpressure, if it exists.
319 try:
320 self._handle_pressure()
321 except Exception:
322 # If handling backpressure fails,clear the buffer and log the
323 # exception.
324 self._buffer.clear()
325 self._logger.failure("Failed clearing backpressure")
326
327 # Try and write immediately.
328 self._connect()
155 return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
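
For illustration, a minimal sketch of the line format that `formatEvent` above hands to the TCP observer: one compact JSON object per event, newline-terminated. The exact keys come from `flatten_event` (not shown in this hunk), so the event dict below is a placeholder.

    import json

    # Same encoder settings as _encoder above: compact separators, raw UTF-8.
    encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))

    # Placeholder for the output of flatten_event(event, metadata, include_time=True).
    flattened = {"log": "Hello from Synapse", "level": "INFO", "server_name": "example.com"}

    line = encoder.encode(flattened) + "\n"
    print(repr(line))
    # '{"log":"Hello from Synapse","level":"INFO","server_name":"example.com"}\n'
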
2323 from twisted.internet import defer
2424
2525 from synapse.logging.context import LoggingContext, PreserveLoggingContext
26 from synapse.logging.opentracing import start_active_span
2627
2728 if TYPE_CHECKING:
2829 import resource
196197
197198 with BackgroundProcessLoggingContext(desc) as context:
198199 context.request = "%s-%i" % (desc, count)
199
200200 try:
201 result = func(*args, **kwargs)
202
203 if inspect.isawaitable(result):
204 result = await result
205
206 return result
201 with start_active_span(desc, tags={"request_id": context.request}):
202 result = func(*args, **kwargs)
203
204 if inspect.isawaitable(result):
205 result = await result
206
207 return result
207208 except Exception:
208209 logger.exception(
209210 "Background process '%s' threw an exception", desc,
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515 import logging
16 from typing import TYPE_CHECKING, Iterable, Optional, Tuple
1617
1718 from twisted.internet import defer
1819
20 from synapse.events import EventBase
21 from synapse.http.client import SimpleHttpClient
1922 from synapse.http.site import SynapseRequest
2023 from synapse.logging.context import make_deferred_yieldable, run_in_background
21 from synapse.types import UserID
24 from synapse.storage.state import StateFilter
25 from synapse.types import JsonDict, UserID, create_requester
26
27 if TYPE_CHECKING:
28 from synapse.server import HomeServer
2229
2330 """
2431 This package defines the 'stable' API which can be used by extension modules which
4148 self._store = hs.get_datastore()
4249 self._auth = hs.get_auth()
4350 self._auth_handler = auth_handler
51
52 # We expose these as properties below in order to attach a helpful docstring.
53 self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
54 self._public_room_list_manager = PublicRoomListManager(hs)
55
56 @property
57 def http_client(self):
58 """Allows making outbound HTTP requests to remote resources.
59
60 An instance of synapse.http.client.SimpleHttpClient
61 """
62 return self._http_client
63
64 @property
65 def public_room_list_manager(self):
66 """Allows adding to, removing from and checking the status of rooms in the
67 public room list.
68
69 An instance of synapse.module_api.PublicRoomListManager
70 """
71 return self._public_room_list_manager
4472
4573 def get_user_by_req(self, req, allow_guest=False):
4674 """Check the access_token provided for a request
265293 await self._auth_handler.complete_sso_login(
266294 registered_user_id, request, client_redirect_url,
267295 )
296
297 @defer.inlineCallbacks
298 def get_state_events_in_room(
299 self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
300 ) -> defer.Deferred:
301 """Gets current state events for the given room.
302
303 (This is exposed for compatibility with the old SpamCheckerApi. We should
304 probably deprecate it and replace it with an async method in a subclass.)
305
306 Args:
307 room_id: The room ID to get state events in.
308 types: The event type and state key (using None
309 to represent 'any') of the room state to acquire.
310
311 Returns:
312 twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
313 The filtered state events in the room.
314 """
315 state_ids = yield defer.ensureDeferred(
316 self._store.get_filtered_current_state_ids(
317 room_id=room_id, state_filter=StateFilter.from_types(types)
318 )
319 )
320 state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
321 return state.values()
322
323 async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
324 """Create and send an event into a room. Membership events are currently not supported.
325
326 Args:
327 event_dict: A dictionary representing the event to send.
328 Required keys are `type`, `room_id`, `sender` and `content`.
329
330 Returns:
331 The event that was sent. If state event deduplication happened, then
332 the previous, duplicate event instead.
333
334 Raises:
335 SynapseError if the event was not allowed.
336 """
337 # Create a requester object
338 requester = create_requester(event_dict["sender"])
339
340 # Create and send the event
341 (
342 event,
343 _,
344 ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
345 requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
346 )
347
348 return event
349
350
351 class PublicRoomListManager:
352 """Contains methods for adding to, removing from and querying whether a room
353 is in the public room list.
354 """
355
356 def __init__(self, hs: "HomeServer"):
357 self._store = hs.get_datastore()
358
359 async def room_is_in_public_room_list(self, room_id: str) -> bool:
360 """Checks whether a room is in the public room list.
361
362 Args:
363 room_id: The ID of the room.
364
365 Returns:
366 Whether the room is in the public room list. Returns False if the room does
367 not exist.
368 """
369 room = await self._store.get_room(room_id)
370 if not room:
371 return False
372
373 return room.get("is_public", False)
374
375 async def add_room_to_public_room_list(self, room_id: str) -> None:
376 """Publishes a room to the public room list.
377
378 Args:
379 room_id: The ID of the room.
380 """
381 await self._store.set_room_is_public(room_id, True)
382
383 async def remove_room_from_public_room_list(self, room_id: str) -> None:
384 """Removes a room from the public room list.
385
386 Args:
387 room_id: The ID of the room.
388 """
389 await self._store.set_room_is_public(room_id, False)
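
A rough usage sketch for the new ModuleApi surface added in this file: `module_api` stands for the ModuleApi instance a third-party module is handed by Synapse, and the room and user IDs are placeholders.

    async def example_module_hook(module_api):
        # Send a plain (non-membership) event into a room on behalf of a local user.
        event = await module_api.create_and_send_event_into_room(
            {
                "type": "org.example.notice",
                "room_id": "!somewhere:example.com",
                "sender": "@bot:example.com",
                "content": {"body": "hello from a module"},
            }
        )

        # Publish the room and then confirm it shows up in the public room list.
        await module_api.public_room_list_manager.add_room_to_public_room_list(
            event.room_id
        )
        return await module_api.public_room_list_manager.room_is_in_public_room_list(
            event.room_id
        )
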
318318 )
319319
320320 if self.federation_sender:
321 self.federation_sender.notify_new_events(max_room_stream_token.stream)
321 self.federation_sender.notify_new_events(max_room_stream_token)
322322
323323 async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
324324 try:
325325 await self.appservice_handler.notify_interested_services(
326 max_room_stream_token.stream
326 max_room_stream_token
327327 )
328328 except Exception:
329329 logger.exception("Error notifying application services of event")
330330
331 async def _notify_app_services_ephemeral(
332 self,
333 stream_key: str,
334 new_token: Union[int, RoomStreamToken],
335 users: Collection[UserID] = [],
336 ):
337 try:
338 stream_token = None
339 if isinstance(new_token, int):
340 stream_token = new_token
341 await self.appservice_handler.notify_interested_services_ephemeral(
342 stream_key, stream_token, users
343 )
344 except Exception:
345 logger.exception("Error notifying application services of event")
346
331347 async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
332348 try:
333 await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
349 await self._pusher_pool.on_new_notifications(max_room_stream_token)
334350 except Exception:
335351             logger.exception("Error notifying pusher pool of event")
336352
338354 self,
339355 stream_key: str,
340356 new_token: Union[int, RoomStreamToken],
341 users: Collection[UserID] = [],
357 users: Collection[Union[str, UserID]] = [],
342358 rooms: Collection[str] = [],
343359 ):
344360 """ Used to inform listeners that something has happened event wise.
365381 logger.exception("Failed to notify listener")
366382
367383 self.notify_replication()
384
385 # Notify appservices
386 run_as_background_process(
387 "_notify_app_services_ephemeral",
388 self._notify_app_services_ephemeral,
389 stream_key,
390 new_token,
391 users,
392 )
368393
369394 def on_new_replication_data(self) -> None:
370395 """Used to inform replication listeners that something has happend
495495 # dedupe when we add callbacks to lru cache nodes, otherwise the number
496496 # of callbacks would grow.
497497 def __call__(self):
498 rules = self.cache.get(self.room_id, None, update_metrics=False)
498 rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
499499 if rules:
500500 rules.invalidate_all()
1717 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
1818
1919 from synapse.metrics.background_process_metrics import run_as_background_process
20 from synapse.types import RoomStreamToken
2021
2122 logger = logging.getLogger(__name__)
2223
9091 pass
9192 self.timed_call = None
9293
93 def on_new_notifications(self, max_stream_ordering):
94 def on_new_notifications(self, max_token: RoomStreamToken):
95 # We just use the minimum stream ordering and ignore the vector clock
96 # component. This is safe to do as long as we *always* ignore the vector
97 # clock components.
98 max_stream_ordering = max_token.stream
99
94100 if self.max_stream_ordering:
95101 self.max_stream_ordering = max(
96102 max_stream_ordering, self.max_stream_ordering
2222 from synapse.logging import opentracing
2323 from synapse.metrics.background_process_metrics import run_as_background_process
2424 from synapse.push import PusherConfigException
25 from synapse.types import RoomStreamToken
2526
2627 from . import push_rule_evaluator, push_tools
2728
113114 if should_check_for_notifs:
114115 self._start_processing()
115116
116 def on_new_notifications(self, max_stream_ordering):
117 def on_new_notifications(self, max_token: RoomStreamToken):
118 # We just use the minimum stream ordering and ignore the vector clock
119 # component. This is safe to do as long as we *always* ignore the vector
120 # clock components.
121 max_stream_ordering = max_token.stream
122
117123 self.max_stream_ordering = max(
118124 max_stream_ordering, self.max_stream_ordering or 0
119125 )
386386 return ret
387387
388388 async def get_message_vars(self, notif, event, room_state_ids):
389 if event.type != EventTypes.Message:
390 return
389 if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
390 return None
391391
392392 sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
393393 sender_state_event = await self.store.get_event(sender_state_event_id)
398398 # sender_hash % the number of default images to choose from
399399 sender_hash = string_ordinal_total(event.sender)
400400
401 msgtype = event.content.get("msgtype")
402
403401 ret = {
404 "msgtype": msgtype,
402 "event_type": event.type,
405403 "is_historical": event.event_id != notif["event_id"],
406404 "id": event.event_id,
407405 "ts": event.origin_server_ts,
409407 "sender_avatar_url": sender_avatar_url,
410408 "sender_hash": sender_hash,
411409 }
410
411 # Encrypted messages don't have any additional useful information.
412 if event.type == EventTypes.Encrypted:
413 return ret
414
415 msgtype = event.content.get("msgtype")
416
417 ret["msgtype"] = msgtype
412418
413419 if msgtype == "m.text":
414420 self.add_text_message_vars(ret, event)
1515
1616 import logging
1717 import re
18 from typing import Any, Dict, List, Pattern, Union
18 from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
1919
2020 from synapse.events import EventBase
2121 from synapse.types import UserID
22 from synapse.util.caches import register_cache
2322 from synapse.util.caches.lrucache import LruCache
2423
2524 logger = logging.getLogger(__name__)
173172 # Similar to _glob_matches, but do not treat display_name as a glob.
174173 r = regex_cache.get((display_name, False, True), None)
175174 if not r:
176 r = re.escape(display_name)
177 r = _re_word_boundary(r)
178 r = re.compile(r, flags=re.IGNORECASE)
175 r1 = re.escape(display_name)
176 r1 = _re_word_boundary(r1)
177 r = re.compile(r1, flags=re.IGNORECASE)
179178 regex_cache[(display_name, False, True)] = r
180179
181 return r.search(body)
182
183 def _get_value(self, dotted_key: str) -> str:
180 return bool(r.search(body))
181
182 def _get_value(self, dotted_key: str) -> Optional[str]:
184183 return self._value_cache.get(dotted_key, None)
185184
186185
187186 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
188 regex_cache = LruCache(50000)
189 register_cache("cache", "regex_push_cache", regex_cache)
187 regex_cache = LruCache(
188 50000, "regex_push_cache"
189 ) # type: LruCache[Tuple[str, bool, bool], Pattern]
190190
191191
192192 def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
204204 if not r:
205205 r = _glob_to_re(glob, word_boundary)
206206 regex_cache[(glob, True, word_boundary)] = r
207 return r.search(value)
207 return bool(r.search(value))
208208 except re.error:
209209 logger.warning("Failed to parse glob to regex: %r", glob)
210210 return False
2323 from synapse.push.emailpusher import EmailPusher
2424 from synapse.push.httppusher import HttpPusher
2525 from synapse.push.pusher import PusherFactory
26 from synapse.types import RoomStreamToken
2627 from synapse.util.async_helpers import concurrently_execute
2728
2829 if TYPE_CHECKING:
185186 )
186187 await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
187188
188 async def on_new_notifications(self, max_stream_id: int):
189 async def on_new_notifications(self, max_token: RoomStreamToken):
189190 if not self.pushers:
190191 # nothing to do here.
191192 return
193
194 # We just use the minimum stream ordering and ignore the vector clock
195 # component. This is safe to do as long as we *always* ignore the vector
196 # clock components.
197 max_stream_id = max_token.stream
192198
193199 if max_stream_id < self._last_room_stream_id_seen:
194200 # Nothing to do
213219
214220 if u in self.pushers:
215221 for p in self.pushers[u].values():
216 p.on_new_notifications(max_stream_id)
222 p.on_new_notifications(max_token)
217223
218224 except Exception:
219225 logger.exception("Exception in pusher on_new_notifications")
9191 if self.CACHE:
9292 self.response_cache = ResponseCache(
9393 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
94 )
94 ) # type: ResponseCache[str]
9595
9696 # We reserve `instance_name` as a parameter to sending requests, so we
9797 # assert here that sub classes don't try and use the name.
6161 self.store = hs.get_datastore()
6262 self.storage = hs.get_storage()
6363 self.clock = hs.get_clock()
64 self.federation_handler = hs.get_handlers().federation_handler
64 self.federation_handler = hs.get_federation_handler()
6565
6666 @staticmethod
6767 async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
4646 def __init__(self, hs):
4747 super().__init__(hs)
4848
49 self.federation_handler = hs.get_handlers().federation_handler
49 self.federation_handler = hs.get_federation_handler()
5050 self.store = hs.get_datastore()
5151 self.clock = hs.get_clock()
5252
4545 "ratelimit": true,
4646 "extra_users": [],
4747 }
48
49 200 OK
50
51 { "stream_id": 12345, "event_id": "$abcdef..." }
52
53 The returned event ID may not match the sent event if it was deduplicated.
4854 """
4955
5056 NAME = "send_event"
115121 "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
116122 )
117123
118 stream_id = await self.event_creation_handler.persist_and_notify_client_event(
124 event = await self.event_creation_handler.persist_and_notify_client_event(
119125 requester, event, context, ratelimit=ratelimit, extra_users=extra_users
120126 )
121127
122 return 200, {"stream_id": stream_id}
128 return (
129 200,
130 {
131 "stream_id": event.internal_metadata.stream_ordering,
132 "event_id": event.event_id,
133 },
134 )
123135
124136
125137 def register_servlets(hs, http_server):
1414
1515 from synapse.storage.database import DatabasePool
1616 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
17 from synapse.util.caches.descriptors import Cache
17 from synapse.util.caches.lrucache import LruCache
1818
1919 from ._base import BaseSlavedStore
2020
2323 def __init__(self, database: DatabasePool, db_conn, hs):
2424 super().__init__(database, db_conn, hs)
2525
26 self.client_ip_last_seen = Cache(
27 name="client_ip_last_seen", keylen=4, max_entries=50000
28 )
26 self.client_ip_last_seen = LruCache(
27 cache_name="client_ip_last_seen", keylen=4, max_size=50000
28 ) # type: LruCache[tuple, int]
2929
3030 async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
3131 now = int(self._clock.time_msec())
4040 if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
4141 return
4242
43 self.client_ip_last_seen.prefill(key, now)
43 self.client_ip_last_seen.set(key, now)
4444
4545 self.hs.get_tcp_replication().send_user_ip(
4646 user_id, access_token, ip, user_agent, device_id, now
190190 async def on_position(self, stream_name: str, instance_name: str, token: int):
191191 self.store.process_replication_rows(stream_name, instance_name, token, [])
192192
193 # We poke the generic "replication" notifier to wake anything up that
194 # may be streaming.
195 self.notifier.notify_replication()
196
193197 def on_remote_server_up(self, server: str):
194198 """Called when get a new REMOTE_SERVER_UP command."""
195199
140140
141141
142142 class PositionCommand(Command):
143 """Sent by the server to tell the client the stream position without
144 needing to send an RDATA.
145
146 Format::
147
148 POSITION <stream_name> <instance_name> <token>
149
150 On receipt of a POSITION command clients should check if they have missed
151 any updates, and if so then fetch them out of band.
143 """Sent by an instance to tell others the stream position without needing to
144 send an RDATA.
145
146 Two tokens are sent, the new position and the last position sent by the
147 instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
148 rows were written by the instance between the `prev_token` and `new_token`.
149 (If an instance hasn't sent a position before then the new position can be
150 used for both.)
151
152 Format::
153
154 POSITION <stream_name> <instance_name> <prev_token> <new_token>
155
156 On receipt of a POSITION command instances should check if they have missed
157 any updates, and if so then fetch them out of band. Instances can check this
158 by comparing their view of the current token for the sending instance with
159 the included `prev_token`.
152160
153161 The `<instance_name>` is the process that sent the command and is the source
154162 of the stream.
156164
157165 NAME = "POSITION"
158166
159 def __init__(self, stream_name, instance_name, token):
167 def __init__(self, stream_name, instance_name, prev_token, new_token):
160168 self.stream_name = stream_name
161169 self.instance_name = instance_name
162 self.token = token
163
164 @classmethod
165 def from_line(cls, line):
166 stream_name, instance_name, token = line.split(" ", 2)
167 return cls(stream_name, instance_name, int(token))
168
169 def to_line(self):
170 return " ".join((self.stream_name, self.instance_name, str(self.token)))
170 self.prev_token = prev_token
171 self.new_token = new_token
172
173 @classmethod
174 def from_line(cls, line):
175 stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
176 return cls(stream_name, instance_name, int(prev_token), int(new_token))
177
178 def to_line(self):
179 return " ".join(
180 (
181 self.stream_name,
182 self.instance_name,
183 str(self.prev_token),
184 str(self.new_token),
185 )
186 )
171187
172188
173189 class ErrorCommand(_SimpleCommand):
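
A small sketch of the updated POSITION wire format defined above; the stream name, instance name and token values are made up.

    from synapse.replication.tcp.commands import PositionCommand

    cmd = PositionCommand("events", "master", 100, 105)

    # to_line() serialises the arguments; the "POSITION " prefix is added by the
    # transport layer using cmd.NAME.
    assert cmd.to_line() == "events master 100 105"

    parsed = PositionCommand.from_line("events master 100 105")
    assert (parsed.prev_token, parsed.new_token) == (100, 105)

    # A receiver whose current token for "master" is already >= 100 knows it has
    # missed nothing and can simply fast-forward to 105.
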
100100 self._streams_to_replicate = [] # type: List[Stream]
101101
102102 for stream in self._streams.values():
103 if stream.NAME == CachesStream.NAME:
104 # All workers can write to the cache invalidation stream.
103 if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
104 # All workers can write to the cache invalidation stream when
105 # using redis.
105106 self._streams_to_replicate.append(stream)
106107 continue
107108
250251 using TCP.
251252 """
252253 if hs.config.redis.redis_enabled:
253 import txredisapi
254
255254 from synapse.replication.tcp.redis import (
256255 RedisDirectTcpReplicationClientFactory,
256 lazyConnection,
257257 )
258258
259259 logger.info(
270270 # connection after SUBSCRIBE is called).
271271
272272 # First create the connection for sending commands.
273 outbound_redis_connection = txredisapi.lazyConnection(
273 outbound_redis_connection = lazyConnection(
274 reactor=hs.get_reactor(),
274275 host=hs.config.redis_host,
275276 port=hs.config.redis_port,
276277 password=hs.config.redis.redis_password,
312313 # We respond with current position of all streams this instance
313314 # replicates.
314315 for stream in self.get_streams_to_replicate():
316 # Note that we use the current token as the prev token here (rather
317 # than stream.last_token), as we can't be sure that there have been
318 # no rows written between last token and the current token (since we
319 # might be racing with the replication sending bg process).
320 current_token = stream.current_token(self._instance_name)
315321 self.send_command(
316322 PositionCommand(
317 stream.NAME,
318 self._instance_name,
319 stream.current_token(self._instance_name),
323 stream.NAME, self._instance_name, current_token, current_token,
320324 )
321325 )
322326
510514 # If the position token matches our current token then we're up to
511515 # date and there's nothing to do. Otherwise, fetch all updates
512516 # between then and now.
513 missing_updates = cmd.token != current_token
517 missing_updates = cmd.prev_token != current_token
514518 while missing_updates:
515519 logger.info(
516520 "Fetching replication rows for '%s' between %i and %i",
517521 stream_name,
518522 current_token,
519 cmd.token,
523 cmd.new_token,
520524 )
521525 (updates, current_token, missing_updates) = await stream.get_updates_since(
522 cmd.instance_name, current_token, cmd.token
526 cmd.instance_name, current_token, cmd.new_token
523527 )
524528
525529 # TODO: add some tests for this
535539 [stream.parse_row(row) for row in rows],
536540 )
537541
538 logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
542 logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
539543
540544 # We've now caught up to position sent to us, notify handler.
541545 await self._replication_data_handler.on_position(
542 cmd.stream_name, cmd.instance_name, cmd.token
546 cmd.stream_name, cmd.instance_name, cmd.new_token
543547 )
544548
545549 self._streams_by_connection.setdefault(conn, set()).add(stream_name)
5050 import logging
5151 import struct
5252 from inspect import isawaitable
53 from typing import TYPE_CHECKING, List
53 from typing import TYPE_CHECKING, List, Optional
5454
5555 from prometheus_client import Counter
5656
57 from twisted.internet import task
5758 from twisted.protocols.basic import LineOnlyReceiver
5859 from twisted.python.failure import Failure
5960
151152
152153 self.last_received_command = self.clock.time_msec()
153154 self.last_sent_command = 0
154 self.time_we_closed = None # When we requested the connection be closed
155
156 self.received_ping = False # Have we reecived a ping from the other side
155 # When we requested the connection be closed
156 self.time_we_closed = None # type: Optional[int]
157
158 self.received_ping = False # Have we received a ping from the other side
157159
158160 self.state = ConnectionStates.CONNECTING
159161
164166 self.pending_commands = [] # type: List[Command]
165167
166168 # The LoopingCall for sending pings.
167 self._send_ping_loop = None
169 self._send_ping_loop = None # type: Optional[task.LoopingCall]
168170
169171 # a logcontext which we use for processing incoming commands. We declare it as a
170172 # background process so that the CPU stats get reported to prometheus.
1414
1515 import logging
1616 from inspect import isawaitable
17 from typing import TYPE_CHECKING
17 from typing import TYPE_CHECKING, Optional
1818
1919 import txredisapi
2020
227227 p.password = self.password
228228
229229 return p
230
231
232 def lazyConnection(
233 reactor,
234 host: str = "localhost",
235 port: int = 6379,
236 dbid: Optional[int] = None,
237 reconnect: bool = True,
238 charset: str = "utf-8",
239 password: Optional[str] = None,
240 connectTimeout: Optional[int] = None,
241 replyTimeout: Optional[int] = None,
242 convertNumbers: bool = True,
243 ) -> txredisapi.RedisProtocol:
244 """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
245 reactor.
246 """
247
248 isLazy = True
249 poolsize = 1
250
251 uuid = "%s:%d" % (host, port)
252 factory = txredisapi.RedisFactory(
253 uuid,
254 dbid,
255 poolsize,
256 isLazy,
257 txredisapi.ConnectionHandler,
258 charset,
259 password,
260 replyTimeout,
261 convertNumbers,
262 )
263 factory.continueTrying = reconnect
264 for x in range(poolsize):
265 reactor.connectTCP(host, port, factory, connectTimeout)
266
267 return factory.handler
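
A brief sketch of the new `lazyConnection` helper in use. Unlike `txredisapi.lazyConnection` it takes the reactor explicitly, which is what lets Synapse pass its own (or a test) reactor; the connection details below are placeholders.

    from twisted.internet import reactor

    from synapse.replication.tcp.redis import lazyConnection

    # Returns a txredisapi.ConnectionHandler immediately; the TCP connection is
    # established lazily once the reactor runs.
    outbound_redis_connection = lazyConnection(
        reactor=reactor,
        host="localhost",
        port=6379,
        password=None,
        reconnect=True,
    )
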
2222 from twisted.internet.protocol import Factory
2323
2424 from synapse.metrics.background_process_metrics import run_as_background_process
25 from synapse.replication.tcp.commands import PositionCommand
2526 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
27 from synapse.replication.tcp.streams import EventsStream
2628 from synapse.util.metrics import Measure
2729
2830 stream_updates_counter = Counter(
8385 # Set of streams to replicate.
8486 self.streams = self.command_handler.get_streams_to_replicate()
8587
88         # If we have streams then we must have redis enabled or be running on the master
89 assert (
90 not self.streams
91 or hs.config.redis.redis_enabled
92 or not hs.config.worker.worker_app
93 )
94
95 # If we are replicating an event stream we want to periodically check if
96         # we should send updated POSITIONs. We do this as a looping call rather than
97 # explicitly poking when the position advances (without new data to
98 # replicate) to reduce replication traffic (otherwise each writer would
99 # likely send a POSITION for each new event received over replication).
100 #
101 # Note that if the position hasn't advanced then we won't send anything.
102 if any(EventsStream.NAME == s.NAME for s in self.streams):
103 self.clock.looping_call(self.on_notifier_poke, 1000)
104
86105 def on_notifier_poke(self):
87106 """Checks if there is actually any new data and sends it to the
88107 connections if there are.
90109 This should get called each time new data is available, even if it
91110 is currently being executed, so that nothing gets missed
92111 """
93 if not self.command_handler.connected():
112 if not self.command_handler.connected() or not self.streams:
94113 # Don't bother if nothing is listening. We still need to advance
95114 # the stream tokens otherwise they'll fall behind forever
96115 for stream in self.streams:
134153 await self.clock.sleep(
135154 self._replication_torture_level / 1000.0
136155 )
156
157 last_token = stream.last_token
137158
138159 logger.debug(
139160 "Getting stream: %s: %s -> %s",
158179 )
159180 stream_updates_counter.labels(stream.NAME).inc(len(updates))
160181
182 else:
183 # The token has advanced but there is no data to
184 # send, so we send a `POSITION` to inform other
185 # workers of the updated position.
186 if stream.NAME == EventsStream.NAME:
187 # XXX: We only do this for the EventStream as it
188 # turns out that e.g. account data streams share
189 # their "current token" with each other, meaning
190 # that it is *not* safe to send a POSITION.
191 logger.info(
192 "Sending position: %s -> %s",
193 stream.NAME,
194 current_token,
195 )
196 self.command_handler.send_command(
197 PositionCommand(
198 stream.NAME,
199 self._instance_name,
200 last_token,
201 current_token,
202 )
203 )
204 continue
205
161206 # Some streams return multiple rows with the same stream IDs,
162207 # we need to make sure they get sent out in batches. We do
163208 # this by setting the current token to all but the last of
239239 ROW_TYPE = BackfillStreamRow
240240
241241 def __init__(self, hs):
242 store = hs.get_datastore()
243 super().__init__(
244 hs.get_instance_name(),
245 current_token_without_instance(store.get_current_backfill_token),
246 store.get_all_new_backfill_event_rows,
247 )
242 self.store = hs.get_datastore()
243 super().__init__(
244 hs.get_instance_name(),
245 self._current_token,
246 self.store.get_all_new_backfill_event_rows,
247 )
248
249 def _current_token(self, instance_name: str) -> int:
250 # The backfill stream over replication operates on *positive* numbers,
251 # which means we need to negate it.
252 return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
248253
249254
250255 class PresenceStream(Stream):
154154 # now we fetch up to that many rows from the events table
155155
156156 event_rows = await self._store.get_all_new_forward_event_rows(
157 from_token, current_token, target_row_count
157 instance_name, from_token, current_token, target_row_count
158158 ) # type: List[Tuple]
159159
160160 # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
179179 upper_limit,
180180 state_rows_limited,
181181 ) = await self._store.get_all_updated_current_state_deltas(
182 from_token, upper_limit, target_row_count
182 instance_name, from_token, upper_limit, target_row_count
183183 )
184184
185185 limited = limited or state_rows_limited
188188 # not to bother with the limit.
189189
190190 ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
191 from_token, upper_limit
191 instance_name, from_token, upper_limit
192192 ) # type: List[Tuple]
193193
194194 # we now need to turn the raw database rows returned into tuples suitable
0 {% for message in notif.messages %}
0 {%- for message in notif.messages %}
11 <tr class="{{ "historical_message" if message.is_historical else "message" }}">
22 <td class="sender_avatar">
3 {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
4 {% if message.sender_avatar_url %}
3 {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
4 {%- if message.sender_avatar_url %}
55 <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
6 {% else %}
7 {% if message.sender_hash % 3 == 0 %}
6 {%- else %}
7 {%- if message.sender_hash % 3 == 0 %}
88 <img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" />
9 {% elif message.sender_hash % 3 == 1 %}
9 {%- elif message.sender_hash % 3 == 1 %}
1010 <img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" />
11 {% else %}
11 {%- else %}
1212 <img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" />
13 {% endif %}
14 {% endif %}
15 {% endif %}
13 {%- endif %}
14 {%- endif %}
15 {%- endif %}
1616 </td>
1717 <td class="message_contents">
18 {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
19 <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
20 {% endif %}
18 {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
19 <div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div>
20 {%- endif %}
2121 <div class="message_body">
22 {% if message.msgtype == "m.text" %}
23 {{ message.body_text_html }}
24 {% elif message.msgtype == "m.emote" %}
25 {{ message.body_text_html }}
26 {% elif message.msgtype == "m.notice" %}
27 {{ message.body_text_html }}
28 {% elif message.msgtype == "m.image" %}
29 <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
30 {% elif message.msgtype == "m.file" %}
31 <span class="filename">{{ message.body_text_plain }}</span>
32 {% endif %}
22 {%- if message.event_type == "m.room.encrypted" %}
23 An encrypted message.
24 {%- elif message.event_type == "m.room.message" %}
25 {%- if message.msgtype == "m.text" %}
26 {{ message.body_text_html }}
27 {%- elif message.msgtype == "m.emote" %}
28 {{ message.body_text_html }}
29 {%- elif message.msgtype == "m.notice" %}
30 {{ message.body_text_html }}
31 {%- elif message.msgtype == "m.image" %}
32 <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
33 {%- elif message.msgtype == "m.file" %}
34 <span class="filename">{{ message.body_text_plain }}</span>
35 {%- else %}
36 A message with unrecognised content.
37 {%- endif %}
38 {%- endif %}
3339 </div>
3440 </td>
3541 <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
3642 </tr>
37 {% endfor %}
43 {%- endfor %}
3844 <tr class="notif_link">
3945 <td></td>
4046 <td>
0 {% for message in notif.messages %}
1 {% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
2 {% if message.msgtype == "m.text" %}
0 {%- for message in notif.messages %}
1 {%- if message.event_type == "m.room.encrypted" %}
2 An encrypted message.
3 {%- elif message.event_type == "m.room.message" %}
4 {%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
5 {%- if message.msgtype == "m.text" %}
36 {{ message.body_text_plain }}
4 {% elif message.msgtype == "m.emote" %}
7 {%- elif message.msgtype == "m.emote" %}
58 {{ message.body_text_plain }}
6 {% elif message.msgtype == "m.notice" %}
9 {%- elif message.msgtype == "m.notice" %}
710 {{ message.body_text_plain }}
8 {% elif message.msgtype == "m.image" %}
11 {%- elif message.msgtype == "m.image" %}
912 {{ message.body_text_plain }}
10 {% elif message.msgtype == "m.file" %}
13 {%- elif message.msgtype == "m.file" %}
1114 {{ message.body_text_plain }}
12 {% endif %}
13 {% endfor %}
15 {%- else %}
16 A message with unrecognised content.
17 {%- endif %}
18 {%- endif %}
19 {%- endfor %}
1420
1521 View {{ room.title }} at {{ notif.link }}
11 <html lang="en">
22 <head>
33 <style type="text/css">
4 {% include 'mail.css' without context %}
5 {% include "mail-%s.css" % app_name ignore missing without context %}
4 {%- include 'mail.css' without context %}
5 {%- include "mail-%s.css" % app_name ignore missing without context %}
66 </style>
77 </head>
88 <body>
1717 <div class="summarytext">{{ summary_text }}</div>
1818 </td>
1919 <td class="logo">
20 {% if app_name == "Riot" %}
20 {%- if app_name == "Riot" %}
2121 <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
22 {% elif app_name == "Vector" %}
22 {%- elif app_name == "Vector" %}
2323 <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
24 {% elif app_name == "Element" %}
24 {%- elif app_name == "Element" %}
2525 <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
26 {% else %}
26 {%- else %}
2727 <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
28 {% endif %}
28 {%- endif %}
2929 </td>
3030 </tr>
3131 </table>
32 {% for room in rooms %}
33 {% include 'room.html' with context %}
34 {% endfor %}
32 {%- for room in rooms %}
33 {%- include 'room.html' with context %}
34 {%- endfor %}
3535 <div class="footer">
3636 <a href="{{ unsubscribe_link }}">Unsubscribe</a>
3737 <br/>
4040 Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
4141 an event was received at {{ reason.received_at|format_ts("%c") }}
4242 which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
43 {% if reason.last_sent_ts %}
43 {%- if reason.last_sent_ts %}
4444 and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
4545 which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
46 {% else %}
46 {%- else %}
4747 and we don't have a last time we sent a mail for this room.
48 {% endif %}
48 {%- endif %}
4949 </div>
5050 </div>
5151 </td>
11
22 {{ summary_text }}
33
4 {% for room in rooms %}
5 {% include 'room.txt' with context %}
6 {% endfor %}
4 {%- for room in rooms %}
5 {%- include 'room.txt' with context %}
6 {%- endfor %}
77
88 You can disable these notifications at {{ unsubscribe_link }}
99
00 <table class="room">
11 <tr class="room_header">
22 <td class="room_avatar">
3 {% if room.avatar_url %}
3 {%- if room.avatar_url %}
44 <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
5 {% else %}
6 {% if room.hash % 3 == 0 %}
5 {%- else %}
6 {%- if room.hash % 3 == 0 %}
77 <img alt="" src="https://riot.im/img/external/avatar-1.png" />
8 {% elif room.hash % 3 == 1 %}
8 {%- elif room.hash % 3 == 1 %}
99 <img alt="" src="https://riot.im/img/external/avatar-2.png" />
10 {% else %}
10 {%- else %}
1111 <img alt="" src="https://riot.im/img/external/avatar-3.png" />
12 {% endif %}
13 {% endif %}
12 {%- endif %}
13 {%- endif %}
1414 </td>
1515 <td class="room_name" colspan="2">
1616 {{ room.title }}
1717 </td>
1818 </tr>
19 {% if room.invite %}
19 {%- if room.invite %}
2020 <tr>
2121 <td></td>
2222 <td>
2424 </td>
2525 <td></td>
2626 </tr>
27 {% else %}
28 {% for notif in room.notifs %}
29 {% include 'notif.html' with context %}
30 {% endfor %}
31 {% endif %}
27 {%- else %}
28 {%- for notif in room.notifs %}
29 {%- include 'notif.html' with context %}
30 {%- endfor %}
31 {%- endif %}
3232 </table>
00 {{ room.title }}
11
2 {% if room.invite %}
2 {%- if room.invite %}
33 You've been invited, join at {{ room.link }}
4 {% else %}
5 {% for notif in room.notifs %}
6 {% include 'notif.txt' with context %}
7 {% endfor %}
8 {% endif %}
4 {%- else %}
5 {%- for notif in room.notifs %}
6 {%- include 'notif.txt' with context %}
7 {%- endfor %}
8 {%- endif %}
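
The template hunks above switch from `{% ... %}` to `{%- ... %}` throughout the notification emails; the leading dash asks Jinja2 to strip the whitespace (including the newline) immediately before the tag, which keeps blank lines out of the rendered HTML and plain-text mails. A quick, self-contained illustration of the effect:

    from jinja2 import Template

    plain   = Template("before\n{% if True %}\nbody\n{% endif %}\nafter").render()
    trimmed = Template("before\n{%- if True %}\nbody\n{%- endif %}\nafter").render()

    print(repr(plain))    # 'before\n\nbody\n\nafter'  (blank lines around the block)
    print(repr(trimmed))  # 'before\nbody\nafter'      (whitespace before each tag stripped)
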
5656 UsersRestServletV2,
5757 WhoisRestServlet,
5858 )
59 from synapse.types import RoomStreamToken
5960 from synapse.util.versionstring import get_version_string
6061
6162 logger = logging.getLogger(__name__)
108109 if event.room_id != room_id:
109110 raise SynapseError(400, "Event is for wrong room.")
110111
111 room_token = await self.store.get_topological_token_for_event(event_id)
112 room_token = RoomStreamToken(
113 event.depth, event.internal_metadata.stream_ordering
114 )
112115 token = await room_token.to_string(self.store)
113116
114117 logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
137137 def __init__(self, hs):
138138 self.store = hs.get_datastore()
139139 self.auth = hs.get_auth()
140 self.admin_handler = hs.get_handlers().admin_handler
140 self.admin_handler = hs.get_admin_handler()
141141
142142 async def on_GET(self, request):
143143 requester = await self.auth.get_user_by_req(request)
272272 self.hs = hs
273273 self.auth = hs.get_auth()
274274 self.room_member_handler = hs.get_room_member_handler()
275 self.admin_handler = hs.get_handlers().admin_handler
275 self.admin_handler = hs.get_admin_handler()
276276 self.state_handler = hs.get_state_handler()
277277
278278 async def on_POST(self, request, room_identifier):
4444 self.hs = hs
4545 self.store = hs.get_datastore()
4646 self.auth = hs.get_auth()
47 self.admin_handler = hs.get_handlers().admin_handler
47 self.admin_handler = hs.get_admin_handler()
4848
4949 async def on_GET(self, request, user_id):
5050 target_user = UserID.from_string(user_id)
8181 self.hs = hs
8282 self.store = hs.get_datastore()
8383 self.auth = hs.get_auth()
84 self.admin_handler = hs.get_handlers().admin_handler
84 self.admin_handler = hs.get_admin_handler()
8585
8686 async def on_GET(self, request):
8787 await assert_requester_is_admin(self.auth, request)
134134 def __init__(self, hs):
135135 self.hs = hs
136136 self.auth = hs.get_auth()
137 self.admin_handler = hs.get_handlers().admin_handler
137 self.admin_handler = hs.get_admin_handler()
138138 self.store = hs.get_datastore()
139139 self.auth_handler = hs.get_auth_handler()
140140 self.profile_handler = hs.get_profile_handler()
447447 def __init__(self, hs):
448448 self.hs = hs
449449 self.auth = hs.get_auth()
450 self.handlers = hs.get_handlers()
450 self.admin_handler = hs.get_admin_handler()
451451
452452 async def on_GET(self, request, user_id):
453453 target_user = UserID.from_string(user_id)
460460 if not self.hs.is_mine(target_user):
461461 raise SynapseError(400, "Can only whois a local user")
462462
463 ret = await self.handlers.admin_handler.get_whois(target_user)
463 ret = await self.admin_handler.get_whois(target_user)
464464
465465 return 200, ret
466466
590590 self.hs = hs
591591 self.store = hs.get_datastore()
592592 self.auth = hs.get_auth()
593 self.handlers = hs.get_handlers()
594593
595594 async def on_GET(self, request, target_user_id):
596595 """Get request to search user table for specific users according to
611610 term = parse_string(request, "term", required=True)
612611 logger.info("term: %s ", term)
613612
614 ret = await self.handlers.store.search_users(term)
613 ret = await self.store.search_users(term)
615614 return 200, ret
616615
617616
4141 def __init__(self, hs):
4242 super().__init__()
4343 self.store = hs.get_datastore()
44 self.handlers = hs.get_handlers()
44 self.directory_handler = hs.get_directory_handler()
4545 self.auth = hs.get_auth()
4646
4747 async def on_GET(self, request, room_alias):
4848 room_alias = RoomAlias.from_string(room_alias)
4949
50 dir_handler = self.handlers.directory_handler
51 res = await dir_handler.get_association(room_alias)
50 res = await self.directory_handler.get_association(room_alias)
5251
5352 return 200, res
5453
7877
7978 requester = await self.auth.get_user_by_req(request)
8079
81 await self.handlers.directory_handler.create_association(
80 await self.directory_handler.create_association(
8281 requester, room_alias, room_id, servers
8382 )
8483
8584 return 200, {}
8685
8786 async def on_DELETE(self, request, room_alias):
88 dir_handler = self.handlers.directory_handler
89
9087 try:
9188 service = self.auth.get_appservice_by_req(request)
9289 room_alias = RoomAlias.from_string(room_alias)
93 await dir_handler.delete_appservice_association(service, room_alias)
90 await self.directory_handler.delete_appservice_association(
91 service, room_alias
92 )
9493 logger.info(
9594 "Application service at %s deleted alias %s",
9695 service.url,
106105
107106 room_alias = RoomAlias.from_string(room_alias)
108107
109 await dir_handler.delete_association(requester, room_alias)
108 await self.directory_handler.delete_association(requester, room_alias)
110109
111110 logger.info(
112111 "User %s deleted alias %s", user.to_string(), room_alias.to_string()
121120 def __init__(self, hs):
122121 super().__init__()
123122 self.store = hs.get_datastore()
124 self.handlers = hs.get_handlers()
123 self.directory_handler = hs.get_directory_handler()
125124 self.auth = hs.get_auth()
126125
127126 async def on_GET(self, request, room_id):
137136 content = parse_json_object_from_request(request)
138137 visibility = content.get("visibility", "public")
139138
140 await self.handlers.directory_handler.edit_published_room_list(
139 await self.directory_handler.edit_published_room_list(
141140 requester, room_id, visibility
142141 )
143142
146145 async def on_DELETE(self, request, room_id):
147146 requester = await self.auth.get_user_by_req(request)
148147
149 await self.handlers.directory_handler.edit_published_room_list(
148 await self.directory_handler.edit_published_room_list(
150149 requester, room_id, "private"
151150 )
152151
161160 def __init__(self, hs):
162161 super().__init__()
163162 self.store = hs.get_datastore()
164 self.handlers = hs.get_handlers()
163 self.directory_handler = hs.get_directory_handler()
165164 self.auth = hs.get_auth()
166165
167166 def on_PUT(self, request, network_id, room_id):
179178 403, "Only appservices can edit the appservice published room list"
180179 )
181180
182 await self.handlers.directory_handler.edit_published_appservice_room_list(
181 await self.directory_handler.edit_published_appservice_room_list(
183182 requester.app_service.id, network_id, room_id, visibility
184183 )
185184
6666
6767 self.auth_handler = self.hs.get_auth_handler()
6868 self.registration_handler = hs.get_registration_handler()
69 self.handlers = hs.get_handlers()
7069 self._well_known_builder = WellKnownBuilder(hs)
7170 self._address_ratelimiter = Ratelimiter(
7271 clock=hs.get_clock(),
109108 flows.extend(
110109 ({"type": t} for t in self.auth_handler.get_supported_login_types())
111110 )
111
112 flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
112113
113114 return 200, {"flows": flows}
114115
5858 try:
5959 new_name = content["displayname"]
6060 except Exception:
61 return 400, "Unable to parse name"
61 raise SynapseError(
62 code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
63 )
6264
6365 await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
6466
111111 class RoomStateEventRestServlet(TransactionRestServlet):
112112 def __init__(self, hs):
113113 super().__init__(hs)
114 self.handlers = hs.get_handlers()
115114 self.event_creation_handler = hs.get_event_creation_handler()
116115 self.room_member_handler = hs.get_room_member_handler()
117116 self.message_handler = hs.get_message_handler()
797796 class RoomRedactEventRestServlet(TransactionRestServlet):
798797 def __init__(self, hs):
799798 super().__init__(hs)
800 self.handlers = hs.get_handlers()
801799 self.event_creation_handler = hs.get_event_creation_handler()
802800 self.auth = hs.get_auth()
803801
902900 def __init__(self, hs: "synapse.server.HomeServer"):
903901 super().__init__()
904902 self.auth = hs.get_auth()
905 self.directory_handler = hs.get_handlers().directory_handler
903 self.directory_handler = hs.get_directory_handler()
906904
907905 async def on_GET(self, request, room_id):
908906 requester = await self.auth.get_user_by_req(request)
919917
920918 def __init__(self, hs):
921919 super().__init__()
922 self.handlers = hs.get_handlers()
920 self.search_handler = hs.get_search_handler()
923921 self.auth = hs.get_auth()
924922
925923 async def on_POST(self, request):
928926 content = parse_json_object_from_request(request)
929927
930928 batch = parse_string(request, "next_batch")
931 results = await self.handlers.search_handler.search(
932 requester.user, content, batch
933 )
929 results = await self.search_handler.search(requester.user, content, batch)
934930
935931 return 200, results
936932
5555 self.hs = hs
5656 self.datastore = hs.get_datastore()
5757 self.config = hs.config
58 self.identity_handler = hs.get_handlers().identity_handler
58 self.identity_handler = hs.get_identity_handler()
5959
6060 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
6161 self.mailer = Mailer(
326326 super().__init__()
327327 self.hs = hs
328328 self.config = hs.config
329 self.identity_handler = hs.get_handlers().identity_handler
329 self.identity_handler = hs.get_identity_handler()
330330 self.store = self.hs.get_datastore()
331331
332332 if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
423423 self.hs = hs
424424 super().__init__()
425425 self.store = self.hs.get_datastore()
426 self.identity_handler = hs.get_handlers().identity_handler
426 self.identity_handler = hs.get_identity_handler()
427427
428428 async def on_POST(self, request):
429429 body = parse_json_object_from_request(request)
573573 self.config = hs.config
574574 self.clock = hs.get_clock()
575575 self.store = hs.get_datastore()
576 self.identity_handler = hs.get_handlers().identity_handler
576 self.identity_handler = hs.get_identity_handler()
577577
578578 async def on_POST(self, request):
579579 if not self.config.account_threepid_delegate_msisdn:
603603 def __init__(self, hs):
604604 super().__init__()
605605 self.hs = hs
606 self.identity_handler = hs.get_handlers().identity_handler
606 self.identity_handler = hs.get_identity_handler()
607607 self.auth = hs.get_auth()
608608 self.auth_handler = hs.get_auth_handler()
609609 self.datastore = self.hs.get_datastore()
659659 def __init__(self, hs):
660660 super().__init__()
661661 self.hs = hs
662 self.identity_handler = hs.get_handlers().identity_handler
662 self.identity_handler = hs.get_identity_handler()
663663 self.auth = hs.get_auth()
664664 self.auth_handler = hs.get_auth_handler()
665665
710710 def __init__(self, hs):
711711 super().__init__()
712712 self.hs = hs
713 self.identity_handler = hs.get_handlers().identity_handler
713 self.identity_handler = hs.get_identity_handler()
714714 self.auth = hs.get_auth()
715715
716716 async def on_POST(self, request):
739739 def __init__(self, hs):
740740 super().__init__()
741741 self.hs = hs
742 self.identity_handler = hs.get_handlers().identity_handler
742 self.identity_handler = hs.get_identity_handler()
743743 self.auth = hs.get_auth()
744744 self.datastore = self.hs.get_datastore()
745745
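The repeated one-line changes above replace the old grab-bag accessor, hs.get_handlers().identity_handler, with a dedicated hs.get_identity_handler() getter on the HomeServer object. A hedged sketch of what a call site looks like after the change (the servlet class name is invented):

class ExampleThreepidRestServlet:
    def __init__(self, hs):
        self.hs = hs
        # Previously: self.identity_handler = hs.get_handlers().identity_handler
        self.identity_handler = hs.get_identity_handler()
        self.auth = hs.get_auth()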
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2020 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
2021 assert_params_in_dict,
2122 parse_json_object_from_request,
2223 )
24 from synapse.http.site import SynapseRequest
2325
2426 from ._base import client_patterns, interactive_auth_handler
2527
150152 return 200, {}
151153
152154
155 class DehydratedDeviceServlet(RestServlet):
156 """Retrieve or store a dehydrated device.
157
158 GET /org.matrix.msc2697.v2/dehydrated_device
159
160 HTTP/1.1 200 OK
161 Content-Type: application/json
162
163 {
164 "device_id": "dehydrated_device_id",
165 "device_data": {
166 "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
167 "account": "dehydrated_device"
168 }
169 }
170
171 PUT /org.matrix.msc2697/dehydrated_device
172 Content-Type: application/json
173
174 {
175 "device_data": {
176 "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
177 "account": "dehydrated_device"
178 }
179 }
180
181 HTTP/1.1 200 OK
182 Content-Type: application/json
183
184 {
185 "device_id": "dehydrated_device_id"
186 }
187
188 """
189
190 PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
191
192 def __init__(self, hs):
193 super().__init__()
194 self.hs = hs
195 self.auth = hs.get_auth()
196 self.device_handler = hs.get_device_handler()
197
198 async def on_GET(self, request: SynapseRequest):
199 requester = await self.auth.get_user_by_req(request)
200 dehydrated_device = await self.device_handler.get_dehydrated_device(
201 requester.user.to_string()
202 )
203 if dehydrated_device is not None:
204 (device_id, device_data) = dehydrated_device
205 result = {"device_id": device_id, "device_data": device_data}
206 return (200, result)
207 else:
208 raise errors.NotFoundError("No dehydrated device available")
209
210 async def on_PUT(self, request: SynapseRequest):
211 submission = parse_json_object_from_request(request)
212 requester = await self.auth.get_user_by_req(request)
213
214 if "device_data" not in submission:
215 raise errors.SynapseError(
216 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
217 )
218 elif not isinstance(submission["device_data"], dict):
219 raise errors.SynapseError(
220 400,
221 "device_data must be an object",
222 errcode=errors.Codes.INVALID_PARAM,
223 )
224
225 device_id = await self.device_handler.store_dehydrated_device(
226 requester.user.to_string(),
227 submission["device_data"],
228 submission.get("initial_device_display_name", None),
229 )
230 return 200, {"device_id": device_id}
231
232
233 class ClaimDehydratedDeviceServlet(RestServlet):
234 """Claim a dehydrated device.
235
236 POST /org.matrix.msc2697.v2/dehydrated_device/claim
237 Content-Type: application/json
238
239 {
240 "device_id": "dehydrated_device_id"
241 }
242
243 HTTP/1.1 200 OK
244 Content-Type: application/json
245
246 {
247 "success": true,
248 }
249
250 """
251
252 PATTERNS = client_patterns(
253 "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
254 )
255
256 def __init__(self, hs):
257 super().__init__()
258 self.hs = hs
259 self.auth = hs.get_auth()
260 self.device_handler = hs.get_device_handler()
261
262 async def on_POST(self, request: SynapseRequest):
263 requester = await self.auth.get_user_by_req(request)
264
265 submission = parse_json_object_from_request(request)
266
267 if "device_id" not in submission:
268 raise errors.SynapseError(
269 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
270 )
271 elif not isinstance(submission["device_id"], str):
272 raise errors.SynapseError(
273 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
274 )
275
276 result = await self.device_handler.rehydrate_device(
277 requester.user.to_string(),
278 self.auth.get_access_token_from_request(request),
279 submission["device_id"],
280 )
281
282 return (200, result)
283
284
153285 def register_servlets(hs, http_server):
154286 DeleteDevicesRestServlet(hs).register(http_server)
155287 DevicesRestServlet(hs).register(http_server)
156288 DeviceRestServlet(hs).register(http_server)
289 DehydratedDeviceServlet(hs).register(http_server)
290 ClaimDehydratedDeviceServlet(hs).register(http_server)
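The docstrings above spell out the unstable MSC2697 request and response bodies. As a rough client-side sketch only: the homeserver URL and access token are placeholders, and the assumption that client_patterns(..., releases=()) exposes these paths under /_matrix/client/unstable is mine, not something this diff states.

import requests

BASE = "https://homeserver.example/_matrix/client/unstable"
HEADERS = {"Authorization": "Bearer ACCESS_TOKEN"}  # placeholder token

# Store a dehydrated device (body mirrors the PUT docstring above).
resp = requests.put(
    f"{BASE}/org.matrix.msc2697.v2/dehydrated_device",
    headers=HEADERS,
    json={
        "device_data": {
            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
            "account": "dehydrated_device",
        }
    },
)
stored_device_id = resp.json()["device_id"]

# Later, from a fresh login: fetch the dehydrated device and claim it.
resp = requests.get(
    f"{BASE}/org.matrix.msc2697.v2/dehydrated_device", headers=HEADERS
)
if resp.status_code == 200:
    requests.post(
        f"{BASE}/org.matrix.msc2697.v2/dehydrated_device/claim",
        headers=HEADERS,
        json={"device_id": resp.json()["device_id"]},
    )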
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
22 # Copyright 2019 New Vector Ltd
3 # Copyright 2020 The Matrix.org Foundation C.I.C.
34 #
45 # Licensed under the Apache License, Version 2.0 (the "License");
56 # you may not use this file except in compliance with the License.
6667 super().__init__()
6768 self.auth = hs.get_auth()
6869 self.e2e_keys_handler = hs.get_e2e_keys_handler()
70 self.device_handler = hs.get_device_handler()
6971
7072 @trace(opname="upload_keys")
7173 async def on_POST(self, request, device_id):
7476 body = parse_json_object_from_request(request)
7577
7678 if device_id is not None:
77 # passing the device_id here is deprecated; however, we allow it
78 # for now for compatibility with older clients.
79 # Providing the device_id should only be done for setting keys
80 # for dehydrated devices; however, we allow it for any device for
81 # compatibility with older clients.
7982 if requester.device_id is not None and device_id != requester.device_id:
80 set_tag("error", True)
81 log_kv(
82 {
83 "message": "Client uploading keys for a different device",
84 "logged_in_id": requester.device_id,
85 "key_being_uploaded": device_id,
86 }
83 dehydrated_device = await self.device_handler.get_dehydrated_device(
84 user_id
8785 )
88 logger.warning(
89 "Client uploading keys for a different device "
90 "(logged in as %s, uploading for %s)",
91 requester.device_id,
92 device_id,
93 )
86 if dehydrated_device is not None and device_id != dehydrated_device[0]:
87 set_tag("error", True)
88 log_kv(
89 {
90 "message": "Client uploading keys for a different device",
91 "logged_in_id": requester.device_id,
92 "key_being_uploaded": device_id,
93 }
94 )
95 logger.warning(
96 "Client uploading keys for a different device "
97 "(logged in as %s, uploading for %s)",
98 requester.device_id,
99 device_id,
100 )
94101 else:
95102 device_id = requester.device_id
96103
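The rewritten check above stops warning when a client uploads keys for its own dehydrated device: the warning now fires only when an explicit device_id matches neither the requester's device nor the stored dehydrated device. A hedged standalone restatement of that decision (the helper name is invented; get_dehydrated_device returns (device_id, device_data) or None, as in the servlet code shown earlier):

async def is_unexpected_device(device_handler, requester, user_id, device_id) -> bool:
    if device_id is None or requester.device_id is None:
        return False
    if device_id == requester.device_id:
        return False
    dehydrated = await device_handler.get_dehydrated_device(user_id)
    # Uploading keys for the dehydrated device itself is expected; anything
    # else is what the servlet above logs a warning about.
    return dehydrated is not None and device_id != dehydrated[0]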
7777 """
7878 super().__init__()
7979 self.hs = hs
80 self.identity_handler = hs.get_handlers().identity_handler
80 self.identity_handler = hs.get_identity_handler()
8181 self.config = hs.config
8282
8383 if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
175175 """
176176 super().__init__()
177177 self.hs = hs
178 self.identity_handler = hs.get_handlers().identity_handler
178 self.identity_handler = hs.get_identity_handler()
179179
180180 async def on_POST(self, request):
181181 body = parse_json_object_from_request(request)
369369 self.store = hs.get_datastore()
370370 self.auth_handler = hs.get_auth_handler()
371371 self.registration_handler = hs.get_registration_handler()
372 self.identity_handler = hs.get_handlers().identity_handler
372 self.identity_handler = hs.get_identity_handler()
373373 self.room_member_handler = hs.get_room_member_handler()
374374 self.macaroon_gen = hs.get_macaroon_generator()
375375 self.ratelimiter = hs.get_registration_ratelimiter()
235235 "leave": sync_result.groups.leave,
236236 },
237237 "device_one_time_keys_count": sync_result.device_one_time_keys_count,
238 "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
238239 "next_batch": await sync_result.next_batch.to_string(self.store),
239240 }
240241
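The added sync field surfaces MSC2732 unused fallback key types next to the existing one-time key counts. A purely illustrative fragment of what a trimmed /sync response body could contain (all values invented):

sync_fragment = {
    "device_one_time_keys_count": {"signed_curve25519": 42},
    "org.matrix.msc2732.device_unused_fallback_key_types": ["signed_curve25519"],
    "next_batch": "s72594_4483_1934",
}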
212212 file_size (int|None): Size in bytes of the media. If not known it should be None
213213 upload_name (str|None): The name of the requested file, if any.
214214 """
215 if request._disconnected:
216 logger.warning(
217 "Not sending response to request %s, already disconnected.", request
218 )
219 return
220
215221 if not responder:
216222 respond_404(request)
217223 return
5353 from synapse.federation.transport.client import TransportLayerClient
5454 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
5555 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
56 from synapse.handlers import Handlers
5756 from synapse.handlers.account_validity import AccountValidityHandler
5857 from synapse.handlers.acme import AcmeHandler
58 from synapse.handlers.admin import AdminHandler
5959 from synapse.handlers.appservice import ApplicationServicesHandler
6060 from synapse.handlers.auth import AuthHandler, MacaroonGenerator
6161 from synapse.handlers.cas_handler import CasHandler
6262 from synapse.handlers.deactivate_account import DeactivateAccountHandler
6363 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
6464 from synapse.handlers.devicemessage import DeviceMessageHandler
65 from synapse.handlers.directory import DirectoryHandler
6566 from synapse.handlers.e2e_keys import E2eKeysHandler
6667 from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
6768 from synapse.handlers.events import EventHandler, EventStreamHandler
69 from synapse.handlers.federation import FederationHandler
6870 from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
71 from synapse.handlers.identity import IdentityHandler
6972 from synapse.handlers.initial_sync import InitialSyncHandler
7073 from synapse.handlers.message import EventCreationHandler, MessageHandler
7174 from synapse.handlers.pagination import PaginationHandler
7275 from synapse.handlers.password_policy import PasswordPolicyHandler
7376 from synapse.handlers.presence import PresenceHandler
74 from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
77 from synapse.handlers.profile import ProfileHandler
7578 from synapse.handlers.read_marker import ReadMarkerHandler
7679 from synapse.handlers.receipts import ReceiptsHandler
7780 from synapse.handlers.register import RegistrationHandler
8386 from synapse.handlers.room_list import RoomListHandler
8487 from synapse.handlers.room_member import RoomMemberMasterHandler
8588 from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
89 from synapse.handlers.search import SearchHandler
8690 from synapse.handlers.set_password import SetPasswordHandler
8791 from synapse.handlers.stats import StatsHandler
8892 from synapse.handlers.sync import SyncHandler
9094 from synapse.handlers.user_directory import UserDirectoryHandler
9195 from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
9296 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
97 from synapse.module_api import ModuleApi
9398 from synapse.notifier import Notifier
9499 from synapse.push.action_generator import ActionGenerator
95100 from synapse.push.pusherpool import PusherPool
184189 we are listening on to provide HTTP services.
185190 """
186191
187 REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
192 REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
193 "account_validity",
194 "auth",
195 "deactivate_account",
196 "message",
197 "pagination",
198 "profile",
199 "stats",
200 ]
188201
189202 # This is overridden in derived application classes
190203 # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
191204 # instantiated during setup() for future return by get_datastore()
192205 DATASTORE_CLASS = abc.abstractproperty()
193206
194 def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
207 def __init__(
208 self,
209 hostname: str,
210 config: HomeServerConfig,
211 reactor=None,
212 version_string="Synapse",
213 ):
195214 """
196215 Args:
197216 hostname : The hostname for the server.
222241 burst_count=config.rc_registration.burst_count,
223242 )
224243
244 self.version_string = version_string
245
225246 self.datastores = None # type: Optional[Databases]
226
227 # Other kwargs are explicit dependencies
228 for depname in kwargs:
229 setattr(self, depname, kwargs[depname])
230247
231248 def get_instance_id(self) -> str:
232249 """A unique ID for this synapse process instance.
250267 self.datastores = Databases(self.DATASTORE_CLASS, self)
251268 logger.info("Finished setting up.")
252269
253 def setup_master(self) -> None:
270 # Register background tasks required by this server. This must be done
271 # somewhat manually due to the background tasks not being registered
272 # unless handlers are instantiated.
273 if self.config.run_background_tasks:
274 self.setup_background_tasks()
275
276 def setup_background_tasks(self) -> None:
254277 """
255278 Some handlers have side effects on instantiation (like registering
256279 background updates). This function causes them to be fetched, and
257280 therefore instantiated, to run those side effects.
258281 """
259 for i in self.REQUIRED_ON_MASTER_STARTUP:
260 getattr(self, "get_" + i)()
282 for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
283 getattr(self, "get_" + i + "_handler")()
261284
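setup_background_tasks turns each short name in REQUIRED_ON_BACKGROUND_TASK_STARTUP into a get_<name>_handler() call, so merely fetching the handler registers its looping calls. A tiny illustration of the name-to-getter mapping (printing instead of calling, since there is no HomeServer instance here):

required = [
    "account_validity", "auth", "deactivate_account", "message",
    "pagination", "profile", "stats",
]
for name in required:
    print("will call", "get_" + name + "_handler")  # e.g. get_pagination_handler()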
262285 def get_reactor(self) -> twisted.internet.base.ReactorBase:
263286 """
306329 @cache_in_self
307330 def get_federation_server(self) -> FederationServer:
308331 return FederationServer(self)
309
310 @cache_in_self
311 def get_handlers(self) -> Handlers:
312 return Handlers(self)
313332
314333 @cache_in_self
315334 def get_notifier(self) -> Notifier:
398417 return DeviceMessageHandler(self)
399418
400419 @cache_in_self
420 def get_directory_handler(self) -> DirectoryHandler:
421 return DirectoryHandler(self)
422
423 @cache_in_self
401424 def get_e2e_keys_handler(self) -> E2eKeysHandler:
402425 return E2eKeysHandler(self)
403426
410433 return AcmeHandler(self)
411434
412435 @cache_in_self
436 def get_admin_handler(self) -> AdminHandler:
437 return AdminHandler(self)
438
439 @cache_in_self
413440 def get_application_service_api(self) -> ApplicationServiceApi:
414441 return ApplicationServiceApi(self)
415442
430457 return EventStreamHandler(self)
431458
432459 @cache_in_self
460 def get_federation_handler(self) -> FederationHandler:
461 return FederationHandler(self)
462
463 @cache_in_self
464 def get_identity_handler(self) -> IdentityHandler:
465 return IdentityHandler(self)
466
467 @cache_in_self
433468 def get_initial_sync_handler(self) -> InitialSyncHandler:
434469 return InitialSyncHandler(self)
435470
436471 @cache_in_self
437472 def get_profile_handler(self):
438 if self.config.worker_app:
439 return BaseProfileHandler(self)
440 else:
441 return MasterProfileHandler(self)
473 return ProfileHandler(self)
442474
443475 @cache_in_self
444476 def get_event_creation_handler(self) -> EventCreationHandler:
447479 @cache_in_self
448480 def get_deactivate_account_handler(self) -> DeactivateAccountHandler:
449481 return DeactivateAccountHandler(self)
482
483 @cache_in_self
484 def get_search_handler(self) -> SearchHandler:
485 return SearchHandler(self)
450486
451487 @cache_in_self
452488 def get_set_password_handler(self) -> SetPasswordHandler:
646682 def get_federation_ratelimiter(self) -> FederationRateLimiter:
647683 return FederationRateLimiter(self.clock, config=self.config.rc_federation)
648684
685 @cache_in_self
686 def get_module_api(self) -> ModuleApi:
687 return ModuleApi(self, self.get_auth_handler())
688
649689 async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
650690 return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
651691
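All of the new getters above (get_directory_handler, get_identity_handler, get_search_handler, get_module_api, ...) rely on the @cache_in_self decorator so each dependency is built once and memoised on the HomeServer. Its implementation is not part of this hunk; the following is only a hedged sketch of such a memoising accessor decorator and may differ from the real one (for example around re-entrancy checks):

import functools

def cache_in_self(builder):
    # Memoise the result of a zero-argument get_* method on the instance.
    attr_name = "_cached_" + builder.__name__

    @functools.wraps(builder)
    def getter(self):
        existing = getattr(self, attr_name, None)
        if existing is not None:
            return existing
        value = builder(self)
        setattr(self, attr_name, value)
        return value

    return getter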
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import logging
1514 from enum import Enum
16
17 from twisted.internet import defer
18
19 from synapse.storage.state import StateFilter
20
21 MYPY = False
22 if MYPY:
23 import synapse.server
24
25 logger = logging.getLogger(__name__)
2615
2716
2817 class RegistrationBehaviour(Enum):
3322 ALLOW = "allow"
3423 SHADOW_BAN = "shadow_ban"
3524 DENY = "deny"
36
37
38 class SpamCheckerApi:
39 """A proxy object that gets passed to spam checkers so they can get
40 access to rooms and other relevant information.
41 """
42
43 def __init__(self, hs: "synapse.server.HomeServer"):
44 self.hs = hs
45
46 self._store = hs.get_datastore()
47
48 @defer.inlineCallbacks
49 def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred:
50 """Gets state events for the given room.
51
52 Args:
53 room_id: The room ID to get state events in.
54 types: The event type and state key (using None
55 to represent 'any') of the room state to acquire.
56
57 Returns:
58 twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
59 The filtered state events in the room.
60 """
61 state_ids = yield defer.ensureDeferred(
62 self._store.get_filtered_current_state_ids(
63 room_id=room_id, state_filter=StateFilter.from_types(types)
64 )
65 )
66 state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
67 return state.values()
737737
738738 # failing that, look for the closest match.
739739 prev_group = None
740 delta_ids = None
740 delta_ids = None # type: Optional[StateMap[str]]
741741
742742 for old_group, old_state in state_groups_ids.items():
743743 n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
7575 """
7676
7777 try:
78 if key is None:
79 getattr(self, cache_name).invalidate_all()
80 else:
81 getattr(self, cache_name).invalidate(tuple(key))
78 cache = getattr(self, cache_name)
8279 except AttributeError:
8380 # We probably haven't pulled in the cache in this worker,
8481 # which is fine.
85 pass
82 return
83
84 if key is None:
85 cache.invalidate_all()
86 else:
87 cache.invalidate(tuple(key))
8688
8789
8890 def db_to_json(db_content):
3131 overload,
3232 )
3333
34 import attr
3435 from prometheus_client import Histogram
3536 from typing_extensions import Literal
3637
8990 return adbapi.ConnectionPool(
9091 db_config.config["name"],
9192 cp_reactor=reactor,
92 cp_openfun=engine.on_new_connection,
93 cp_openfun=lambda conn: engine.on_new_connection(
94 LoggingDatabaseConnection(conn, engine, "on_new_connection")
95 ),
9396 **db_config.config.get("args", {})
9497 )
9598
9699
97100 def make_conn(
98 db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
101 db_config: DatabaseConnectionConfig,
102 engine: BaseDatabaseEngine,
103 default_txn_name: str,
99104 ) -> Connection:
100105 """Make a new connection to the database and return it.
101106
108113 for k, v in db_config.config.get("args", {}).items()
109114 if not k.startswith("cp_")
110115 }
111 db_conn = engine.module.connect(**db_params)
116 native_db_conn = engine.module.connect(**db_params)
117 db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
118
112119 engine.on_new_connection(db_conn)
113120 return db_conn
121
122
123 @attr.s(slots=True)
124 class LoggingDatabaseConnection:
125 """A wrapper around a database connection that returns `LoggingTransaction`
126 as its cursor class.
127
128 This is mainly used on startup to ensure that queries get logged correctly
129 """
130
131 conn = attr.ib(type=Connection)
132 engine = attr.ib(type=BaseDatabaseEngine)
133 default_txn_name = attr.ib(type=str)
134
135 def cursor(
136 self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
137 ) -> "LoggingTransaction":
138 if not txn_name:
139 txn_name = self.default_txn_name
140
141 return LoggingTransaction(
142 self.conn.cursor(),
143 name=txn_name,
144 database_engine=self.engine,
145 after_callbacks=after_callbacks,
146 exception_callbacks=exception_callbacks,
147 )
148
149 def close(self) -> None:
150 self.conn.close()
151
152 def commit(self) -> None:
153 self.conn.commit()
154
155 def rollback(self, *args, **kwargs) -> None:
156 self.conn.rollback(*args, **kwargs)
157
158 def __enter__(self) -> "Connection":
159 self.conn.__enter__()
160 return self
161
162 def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
163 return self.conn.__exit__(exc_type, exc_value, traceback)
164
165 # Proxy through any unknown lookups to the DB conn class.
166 def __getattr__(self, name):
167 return getattr(self.conn, name)
114168
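LoggingDatabaseConnection wraps a raw DBAPI connection so every cursor comes back as a named LoggingTransaction, which is what lets startup-time queries be logged consistently. A hedged sketch of the intended calling pattern only; raw_conn and engine stand in for a DBAPI connection and a BaseDatabaseEngine created elsewhere (see make_conn above):

db_conn = LoggingDatabaseConnection(raw_conn, engine, default_txn_name="startup")
txn = db_conn.cursor(txn_name="check_schema")   # returns a LoggingTransaction
txn.execute("SELECT 1")
txn.close()
db_conn.commit()
db_conn.close()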
115169
116170 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
246300 def close(self) -> None:
247301 self.txn.close()
248302
303 def __enter__(self) -> "LoggingTransaction":
304 return self
305
306 def __exit__(self, exc_type, exc_value, traceback):
307 self.close()
308
249309
250310 class PerformanceCounters:
251311 def __init__(self):
394454
395455 def new_transaction(
396456 self,
397 conn: Connection,
457 conn: LoggingDatabaseConnection,
398458 desc: str,
399459 after_callbacks: List[_CallbackListEntry],
400460 exception_callbacks: List[_CallbackListEntry],
435495 i = 0
436496 N = 5
437497 while True:
438 cursor = LoggingTransaction(
439 conn.cursor(),
440 name,
441 self.engine,
442 after_callbacks,
443 exception_callbacks,
498 cursor = conn.cursor(
499 txn_name=name,
500 after_callbacks=after_callbacks,
501 exception_callbacks=exception_callbacks,
444502 )
445503 try:
446504 r = func(cursor, *args, **kwargs)
637695 if db_autocommit:
638696 self.engine.attempt_to_set_autocommit(conn, True)
639697
640 return func(conn, *args, **kwargs)
698 db_conn = LoggingDatabaseConnection(
699 conn, self.engine, "runWithConnection"
700 )
701 return func(db_conn, *args, **kwargs)
641702 finally:
642703 if db_autocommit:
643704 self.engine.attempt_to_set_autocommit(conn, False)
831892 attempts = 0
832893 while True:
833894 try:
895 # We can autocommit if we are going to use native upserts
896 autocommit = (
897 self.engine.can_native_upsert
898 and table not in self._unsafe_to_upsert_tables
899 )
900
834901 return await self.runInteraction(
835902 desc,
836903 self.simple_upsert_txn,
839906 values,
840907 insertion_values,
841908 lock=lock,
909 db_autocommit=autocommit,
842910 )
843911 except self.engine.module.IntegrityError as e:
844912 attempts += 1
10011069 )
10021070 txn.execute(sql, list(allvalues.values()))
10031071
1072 async def simple_upsert_many(
1073 self,
1074 table: str,
1075 key_names: Collection[str],
1076 key_values: Collection[Iterable[Any]],
1077 value_names: Collection[str],
1078 value_values: Iterable[Iterable[Any]],
1079 desc: str,
1080 ) -> None:
1081 """
1082 Upsert, many times.
1083
1084 Args:
1085 table: The table to upsert into
1086 key_names: The key column names.
1087 key_values: A list of each row's key column values.
1088 value_names: The value column names
1089 value_values: A list of each row's value column values.
1090 Ignored if value_names is empty.
1091 """
1092
1093 # We can autocommit if we are going to use native upserts
1094 autocommit = (
1095 self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
1096 )
1097
1098 return await self.runInteraction(
1099 desc,
1100 self.simple_upsert_many_txn,
1101 table,
1102 key_names,
1103 key_values,
1104 value_names,
1105 value_values,
1106 db_autocommit=autocommit,
1107 )
1108
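simple_upsert_many batches several upserts into one interaction and, where the engine supports native upserts and the table is not in _unsafe_to_upsert_tables, runs it in autocommit mode to skip the surrounding BEGIN/COMMIT. A hedged call-site sketch from inside an async store method; the table and column names are invented for illustration:

await self.db_pool.simple_upsert_many(
    table="device_fallback_keys",
    key_names=("user_id", "device_id", "algorithm"),
    key_values=[("@alice:example.org", "DEVICEID", "signed_curve25519")],
    value_names=("key_json",),
    value_values=[('{"key": "base64+key"}',)],
    desc="store_fallback_keys",
)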
10041109 def simple_upsert_many_txn(
10051110 self,
10061111 txn: LoggingTransaction,
11521257 desc: description of the transaction, for logging and metrics
11531258 """
11541259 return await self.runInteraction(
1155 desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
1260 desc,
1261 self.simple_select_one_txn,
1262 table,
1263 keyvalues,
1264 retcols,
1265 allow_none,
1266 db_autocommit=True,
11561267 )
11571268
11581269 @overload
12031314 keyvalues,
12041315 retcol,
12051316 allow_none=allow_none,
1317 db_autocommit=True,
12061318 )
12071319
12081320 @overload
12841396 Results in a list
12851397 """
12861398 return await self.runInteraction(
1287 desc, self.simple_select_onecol_txn, table, keyvalues, retcol
1399 desc,
1400 self.simple_select_onecol_txn,
1401 table,
1402 keyvalues,
1403 retcol,
1404 db_autocommit=True,
12881405 )
12891406
12901407 async def simple_select_list(
13091426 A list of dictionaries.
13101427 """
13111428 return await self.runInteraction(
1312 desc, self.simple_select_list_txn, table, keyvalues, retcols
1429 desc,
1430 self.simple_select_list_txn,
1431 table,
1432 keyvalues,
1433 retcols,
1434 db_autocommit=True,
13131435 )
13141436
13151437 @classmethod
13881510 chunk,
13891511 keyvalues,
13901512 retcols,
1513 db_autocommit=True,
13911514 )
13921515
13931516 results.extend(rows)
14861609 desc: description of the transaction, for logging and metrics
14871610 """
14881611 await self.runInteraction(
1489 desc, self.simple_update_one_txn, table, keyvalues, updatevalues
1612 desc,
1613 self.simple_update_one_txn,
1614 table,
1615 keyvalues,
1616 updatevalues,
1617 db_autocommit=True,
14901618 )
14911619
14921620 @classmethod
15451673 keyvalues: dict of column names and values to select the row with
15461674 desc: description of the transaction, for logging and metrics
15471675 """
1548 await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
1676 await self.runInteraction(
1677 desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
1678 )
15491679
15501680 @staticmethod
15511681 def simple_delete_one_txn(
15841714 Returns:
15851715 The number of deleted rows.
15861716 """
1587 return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
1717 return await self.runInteraction(
1718 desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
1719 )
15881720
15891721 @staticmethod
15901722 def simple_delete_txn(
16321764 Number rows deleted
16331765 """
16341766 return await self.runInteraction(
1635 desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
1767 desc,
1768 self.simple_delete_many_txn,
1769 table,
1770 column,
1771 iterable,
1772 keyvalues,
1773 db_autocommit=True,
16361774 )
16371775
16381776 @staticmethod
16771815
16781816 def get_cache_dict(
16791817 self,
1680 db_conn: Connection,
1818 db_conn: LoggingDatabaseConnection,
16811819 table: str,
16821820 entity_column: str,
16831821 stream_column: str,
16981836 "limit": limit,
16991837 }
17001838
1701 sql = self.engine.convert_param_style(sql)
1702
1703 txn = db_conn.cursor()
1839 txn = db_conn.cursor(txn_name="get_cache_dict")
17041840 txn.execute(sql, (int(max_value),))
17051841
17061842 cache = {row[0]: int(row[1]) for row in txn}
18001936 """
18011937
18021938 return await self.runInteraction(
1803 desc, self.simple_search_list_txn, table, term, col, retcols
1939 desc,
1940 self.simple_search_list_txn,
1941 table,
1942 term,
1943 col,
1944 retcols,
1945 db_autocommit=True,
18041946 )
18051947
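The same db_autocommit=True treatment is applied to the simple select/update/delete helpers above: a single-statement transaction gains nothing from an explicit BEGIN/COMMIT pair. A hedged sketch of the pattern for a custom single-statement read (the table name is invented):

async def count_rows(db_pool) -> int:
    def _count_rows_txn(txn):
        txn.execute("SELECT count(*) FROM example_table")
        return txn.fetchone()[0]

    # Autocommit skips the explicit transaction wrapper for this one statement.
    return await db_pool.runInteraction(
        "count_rows", _count_rows_txn, db_autocommit=True
    )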
18061948 @classmethod
4545 db_name = database_config.name
4646 engine = create_engine(database_config.config)
4747
48 with make_conn(database_config, engine) as db_conn:
48 with make_conn(database_config, engine, "startup") as db_conn:
4949 logger.info("[database config %r]: Checking database server", db_name)
5050 engine.check_database(db_conn)
5151
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616
17 import calendar
1817 import logging
19 import time
2018 from typing import Any, Dict, List, Optional, Tuple
2119
2220 from synapse.api.constants import PresenceState
267265 self._stream_order_on_start = self.get_room_max_stream_ordering()
268266 self._min_stream_order_on_start = self.get_room_min_stream_ordering()
269267
270 # Used in _generate_user_daily_visits to keep track of progress
271 self._last_user_visit_update = self._get_start_of_day()
272
273268 def get_device_stream_token(self) -> int:
274269 return self._device_list_id_gen.get_current_token()
275270
288283 " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
289284 " WHERE state != ?"
290285 )
291 sql = self.database_engine.convert_param_style(sql)
292286
293287 txn = db_conn.cursor()
294288 txn.execute(sql, (PresenceState.OFFLINE,))
299293 row["currently_active"] = bool(row["currently_active"])
300294
301295 return [UserPresenceState(**row) for row in rows]
302
303 async def count_daily_users(self) -> int:
304 """
305 Counts the number of users who used this homeserver in the last 24 hours.
306 """
307 yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
308 return await self.db_pool.runInteraction(
309 "count_daily_users", self._count_users, yesterday
310 )
311
312 async def count_monthly_users(self) -> int:
313 """
314 Counts the number of users who used this homeserver in the last 30 days.
315 Note this method is intended for phonehome metrics only and is different
316 from the mau figure in synapse.storage.monthly_active_users which,
317 amongst other things, includes a 3 day grace period before a user counts.
318 """
319 thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
320 return await self.db_pool.runInteraction(
321 "count_monthly_users", self._count_users, thirty_days_ago
322 )
323
324 def _count_users(self, txn, time_from):
325 """
326 Returns number of users seen in the past time_from period
327 """
328 sql = """
329 SELECT COALESCE(count(*), 0) FROM (
330 SELECT user_id FROM user_ips
331 WHERE last_seen > ?
332 GROUP BY user_id
333 ) u
334 """
335 txn.execute(sql, (time_from,))
336 (count,) = txn.fetchone()
337 return count
338
339 async def count_r30_users(self) -> Dict[str, int]:
340 """
341 Counts the number of 30 day retained users, defined as:-
342 * Users who have created their accounts more than 30 days ago
343 * Where last seen at most 30 days ago
344 * Where account creation and last_seen are > 30 days apart
345
346 Returns:
347 A mapping of counts globally as well as broken out by platform.
348 """
349
350 def _count_r30_users(txn):
351 thirty_days_in_secs = 86400 * 30
352 now = int(self._clock.time())
353 thirty_days_ago_in_secs = now - thirty_days_in_secs
354
355 sql = """
356 SELECT platform, COALESCE(count(*), 0) FROM (
357 SELECT
358 users.name, platform, users.creation_ts * 1000,
359 MAX(uip.last_seen)
360 FROM users
361 INNER JOIN (
362 SELECT
363 user_id,
364 last_seen,
365 CASE
366 WHEN user_agent LIKE '%%Android%%' THEN 'android'
367 WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
368 WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
369 WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
370 WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
371 ELSE 'unknown'
372 END
373 AS platform
374 FROM user_ips
375 ) uip
376 ON users.name = uip.user_id
377 AND users.appservice_id is NULL
378 AND users.creation_ts < ?
379 AND uip.last_seen/1000 > ?
380 AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
381 GROUP BY users.name, platform, users.creation_ts
382 ) u GROUP BY platform
383 """
384
385 results = {}
386 txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
387
388 for row in txn:
389 if row[0] == "unknown":
390 pass
391 results[row[0]] = row[1]
392
393 sql = """
394 SELECT COALESCE(count(*), 0) FROM (
395 SELECT users.name, users.creation_ts * 1000,
396 MAX(uip.last_seen)
397 FROM users
398 INNER JOIN (
399 SELECT
400 user_id,
401 last_seen
402 FROM user_ips
403 ) uip
404 ON users.name = uip.user_id
405 AND appservice_id is NULL
406 AND users.creation_ts < ?
407 AND uip.last_seen/1000 > ?
408 AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
409 GROUP BY users.name, users.creation_ts
410 ) u
411 """
412
413 txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
414
415 (count,) = txn.fetchone()
416 results["all"] = count
417
418 return results
419
420 return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
421
422 def _get_start_of_day(self):
423 """
424 Returns millisecond unixtime for start of UTC day.
425 """
426 now = time.gmtime()
427 today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
428 return today_start * 1000
429
430 async def generate_user_daily_visits(self) -> None:
431 """
432 Generates daily visit data for use in cohort/ retention analysis
433 """
434
435 def _generate_user_daily_visits(txn):
436 logger.info("Calling _generate_user_daily_visits")
437 today_start = self._get_start_of_day()
438 a_day_in_milliseconds = 24 * 60 * 60 * 1000
439 now = self.clock.time_msec()
440
441 sql = """
442 INSERT INTO user_daily_visits (user_id, device_id, timestamp)
443 SELECT u.user_id, u.device_id, ?
444 FROM user_ips AS u
445 LEFT JOIN (
446 SELECT user_id, device_id, timestamp FROM user_daily_visits
447 WHERE timestamp = ?
448 ) udv
449 ON u.user_id = udv.user_id AND u.device_id=udv.device_id
450 INNER JOIN users ON users.name=u.user_id
451 WHERE last_seen > ? AND last_seen <= ?
452 AND udv.timestamp IS NULL AND users.is_guest=0
453 AND users.appservice_id IS NULL
454 GROUP BY u.user_id, u.device_id
455 """
456
457 # This means that the day has rolled over but there could still
458 # be entries from the previous day. There is an edge case
459 # where if the user logs in at 23:59 and overwrites their
460 # last_seen at 00:01 then they will not be counted in the
461 # previous day's stats - it is important that the query is run
462 # often to minimise this case.
463 if today_start > self._last_user_visit_update:
464 yesterday_start = today_start - a_day_in_milliseconds
465 txn.execute(
466 sql,
467 (
468 yesterday_start,
469 yesterday_start,
470 self._last_user_visit_update,
471 today_start,
472 ),
473 )
474 self._last_user_visit_update = today_start
475
476 txn.execute(
477 sql, (today_start, today_start, self._last_user_visit_update, now)
478 )
479 # Update _last_user_visit_update to now. The reason to do this
480 # rather just clamping to the beginning of the day is to limit
481 # the size of the join - meaning that the query can be run more
482 # frequently
483 self._last_user_visit_update = now
484
485 await self.db_pool.runInteraction(
486 "generate_user_daily_visits", _generate_user_daily_visits
487 )
488296
489297 async def get_users(self) -> List[Dict[str, Any]]:
490298 """Function to retrieve a list of users in users table.
1717 import logging
1818 from typing import Dict, List, Optional, Tuple
1919
20 from synapse.api.constants import AccountDataTypes
2021 from synapse.storage._base import SQLBaseStore, db_to_json
2122 from synapse.storage.database import DatabasePool
2223 from synapse.storage.util.id_generators import StreamIdGenerator
290291 self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
291292 ) -> bool:
292293 ignored_account_data = await self.get_global_account_data_by_type_for_user(
293 "m.ignored_user_list",
294 AccountDataTypes.IGNORED_USER_LIST,
294295 ignorer_user_id,
295296 on_invalidate=cache_context.invalidate,
296297 )
297298 if not ignored_account_data:
298299 return False
299300
300 return ignored_user_id in ignored_account_data.get("ignored_users", {})
301 try:
302 return ignored_user_id in ignored_account_data.get("ignored_users", {})
303 except TypeError:
304 # The type of the ignored_users field is invalid.
305 return False
301306
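The try/except above makes a malformed m.ignored_user_list (for example, ignored_users stored as a number) read as "not ignored" instead of raising for the caller. A small standalone restatement of that check, runnable as-is:

def is_ignored(ignored_account_data: dict, ignored_user_id: str) -> bool:
    if not ignored_account_data:
        return False
    try:
        return ignored_user_id in ignored_account_data.get("ignored_users", {})
    except TypeError:
        # The ignored_users field has an invalid type; treat as not ignored.
        return False

assert is_ignored({"ignored_users": {"@spam:example.org": {}}}, "@spam:example.org")
assert not is_ignored({"ignored_users": 5}, "@spam:example.org")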
302307
303308 class AccountDataStore(AccountDataWorkerStore):
1414 # limitations under the License.
1515 import logging
1616 import re
17
18 from synapse.appservice import AppServiceTransaction
17 from typing import List
18
19 from synapse.appservice import ApplicationService, AppServiceTransaction
1920 from synapse.config.appservice import load_appservices
21 from synapse.events import EventBase
2022 from synapse.storage._base import SQLBaseStore, db_to_json
2123 from synapse.storage.database import DatabasePool
2224 from synapse.storage.databases.main.events_worker import EventsWorkerStore
25 from synapse.types import JsonDict
2326 from synapse.util import json_encoder
2427
2528 logger = logging.getLogger(__name__)
171174 "application_services_state", {"as_id": service.id}, {"state": state}
172175 )
173176
174 async def create_appservice_txn(self, service, events):
177 async def create_appservice_txn(
178 self,
179 service: ApplicationService,
180 events: List[EventBase],
181 ephemeral: List[JsonDict],
182 ) -> AppServiceTransaction:
175183 """Atomically creates a new transaction for this application service
176 with the given list of events.
177
178 Args:
179 service(ApplicationService): The service who the transaction is for.
180 events(list<Event>): A list of events to put in the transaction.
181 Returns:
182 AppServiceTransaction: A new transaction.
184 with the given list of events. Ephemeral events are NOT persisted to the
185 database and are not resent if a transaction is retried.
186
187 Args:
188 service: The service who the transaction is for.
189 events: A list of persistent events to put in the transaction.
190 ephemeral: A list of ephemeral events to put in the transaction.
191
192 Returns:
193 A new transaction.
183194 """
184195
185196 def _create_appservice_txn(txn):
206217 "VALUES(?,?,?)",
207218 (service.id, new_txn_id, event_ids),
208219 )
209 return AppServiceTransaction(service=service, id=new_txn_id, events=events)
220 return AppServiceTransaction(
221 service=service, id=new_txn_id, events=events, ephemeral=ephemeral
222 )
210223
211224 return await self.db_pool.runInteraction(
212225 "create_appservice_txn", _create_appservice_txn
295308
296309 events = await self.get_events_as_list(event_ids)
297310
298 return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
311 return AppServiceTransaction(
312 service=service, id=entry["txn_id"], events=events, ephemeral=[]
313 )
299314
300315 def _get_last_txn(self, txn, service_id):
301316 txn.execute(
319334 )
320335
321336 async def get_new_events_for_appservice(self, current_id, limit):
322 """Get all new evnets"""
337 """Get all new events for an appservice"""
323338
324339 def get_new_events_for_appservice_txn(txn):
325340 sql = (
350365
351366 return upper_bound, events
352367
368 async def get_type_stream_id_for_appservice(
369 self, service: ApplicationService, type: str
370 ) -> int:
371 if type not in ("read_receipt", "presence"):
372 raise ValueError(
373 "Expected type to be a valid application stream id type, got %s"
374 % (type,)
375 )
376
377 def get_type_stream_id_for_appservice_txn(txn):
378 stream_id_type = "%s_stream_id" % type
379 txn.execute(
380 # We do NOT want to escape `stream_id_type`.
381 "SELECT %s FROM application_services_state WHERE as_id=?"
382 % stream_id_type,
383 (service.id,),
384 )
385 last_stream_id = txn.fetchone()
386 if last_stream_id is None or last_stream_id[0] is None: # no row exists
387 return 0
388 else:
389 return int(last_stream_id[0])
390
391 return await self.db_pool.runInteraction(
392 "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
393 )
394
395 async def set_type_stream_id_for_appservice(
396 self, service: ApplicationService, type: str, pos: int
397 ) -> None:
398 if type not in ("read_receipt", "presence"):
399 raise ValueError(
400 "Expected type to be a valid application stream id type, got %s"
401 % (type,)
402 )
403
404 def set_type_stream_id_for_appservice_txn(txn):
405 stream_id_type = "%s_stream_id" % type
406 txn.execute(
407 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
408 % stream_id_type,
409 (pos, service.id),
410 )
411
412 await self.db_pool.runInteraction(
413 "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
414 )
415
353416
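get_type_stream_id_for_appservice and set_type_stream_id_for_appservice track per-appservice positions for the read_receipt and presence streams, defaulting to 0 when no row exists. A hedged usage sketch; store and service stand in for an ApplicationServiceTransactionWorkerStore and an ApplicationService obtained elsewhere:

async def advance_read_receipt_pos(store, service, new_pos: int) -> None:
    last_pos = await store.get_type_stream_id_for_appservice(service, "read_receipt")
    if new_pos > last_pos:
        await store.set_type_stream_id_for_appservice(service, "read_receipt", new_pos)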
354417 class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
355418 # This is currently empty due to there not being any AS storage functions
1616 from typing import TYPE_CHECKING
1717
1818 from synapse.events.utils import prune_event_dict
19 from synapse.metrics.background_process_metrics import run_as_background_process
19 from synapse.metrics.background_process_metrics import wrap_as_background_process
2020 from synapse.storage._base import SQLBaseStore
2121 from synapse.storage.database import DatabasePool
2222 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
23 from synapse.storage.databases.main.events import encode_json
2423 from synapse.storage.databases.main.events_worker import EventsWorkerStore
24 from synapse.util import json_encoder
2525
2626 if TYPE_CHECKING:
2727 from synapse.server import HomeServer
3434 def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
3535 super().__init__(database, db_conn, hs)
3636
37 def _censor_redactions():
38 return run_as_background_process(
39 "_censor_redactions", self._censor_redactions
40 )
41
42 if self.hs.config.redaction_retention_period is not None:
43 hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
44
37 if (
38 hs.config.run_background_tasks
39 and self.hs.config.redaction_retention_period is not None
40 ):
41 hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
42
43 @wrap_as_background_process("_censor_redactions")
4544 async def _censor_redactions(self):
4645 """Censors all redactions older than the configured period that haven't
4746 been censored yet.
104103 and original_event.internal_metadata.is_redacted()
105104 ):
106105 # Redaction was allowed
107 pruned_json = encode_json(
106 pruned_json = json_encoder.encode(
108107 prune_event_dict(
109108 original_event.room_version, original_event.get_dict()
110109 )
170169 return
171170
172171 # Prune the event's dict then convert it to JSON.
173 pruned_json = encode_json(
172 pruned_json = json_encoder.encode(
174173 prune_event_dict(event.room_version, event.get_dict())
175174 )
176175
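The redaction-censoring loop now uses the @wrap_as_background_process decorator and is only scheduled when run_background_tasks is enabled, matching the REQUIRED_ON_BACKGROUND_TASK_STARTUP change earlier. A hedged sketch of the same pattern on an invented store class (the method body is illustrative):

from synapse.metrics.background_process_metrics import wrap_as_background_process

class ExampleStore:
    def __init__(self, hs):
        if hs.config.run_background_tasks:
            hs.get_clock().looping_call(self._do_cleanup, 5 * 60 * 1000)

    @wrap_as_background_process("_do_cleanup")
    async def _do_cleanup(self):
        ...  # periodic work, measured and logged as a background process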
1818 from synapse.metrics.background_process_metrics import wrap_as_background_process
1919 from synapse.storage._base import SQLBaseStore
2020 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
21 from synapse.util.caches.descriptors import Cache
21 from synapse.util.caches.lrucache import LruCache
2222
2323 logger = logging.getLogger(__name__)
2424
350350 return updated
351351
352352
353 class ClientIpStore(ClientIpBackgroundUpdateStore):
353 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
354354 def __init__(self, database: DatabasePool, db_conn, hs):
355
356 self.client_ip_last_seen = Cache(
357 name="client_ip_last_seen", keylen=4, max_entries=50000
358 )
359
360355 super().__init__(database, db_conn, hs)
361356
362357 self.user_ips_max_age = hs.config.user_ips_max_age
363358
364 # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
365 self._batch_row_update = {}
366
367 self._client_ip_looper = self._clock.looping_call(
368 self._update_client_ips_batch, 5 * 1000
369 )
370 self.hs.get_reactor().addSystemEventTrigger(
371 "before", "shutdown", self._update_client_ips_batch
372 )
373
374 if self.user_ips_max_age:
359 if hs.config.run_background_tasks and self.user_ips_max_age:
375360 self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
376
377 async def insert_client_ip(
378 self, user_id, access_token, ip, user_agent, device_id, now=None
379 ):
380 if not now:
381 now = int(self._clock.time_msec())
382 key = (user_id, access_token, ip)
383
384 try:
385 last_seen = self.client_ip_last_seen.get(key)
386 except KeyError:
387 last_seen = None
388 await self.populate_monthly_active_users(user_id)
389 # Rate-limited inserts
390 if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
391 return
392
393 self.client_ip_last_seen.prefill(key, now)
394
395 self._batch_row_update[key] = (user_agent, device_id, now)
396
397 @wrap_as_background_process("update_client_ips")
398 async def _update_client_ips_batch(self) -> None:
399
400 # If the DB pool has already terminated, don't try updating
401 if not self.db_pool.is_running():
402 return
403
404 to_update = self._batch_row_update
405 self._batch_row_update = {}
406
407 await self.db_pool.runInteraction(
408 "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
409 )
410
411 def _update_client_ips_batch_txn(self, txn, to_update):
412 if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
413 not self.database_engine.can_native_upsert
414 ):
415 self.database_engine.lock_table(txn, "user_ips")
416
417 for entry in to_update.items():
418 (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
419
420 try:
421 self.db_pool.simple_upsert_txn(
422 txn,
423 table="user_ips",
424 keyvalues={
425 "user_id": user_id,
426 "access_token": access_token,
427 "ip": ip,
428 },
429 values={
430 "user_agent": user_agent,
431 "device_id": device_id,
432 "last_seen": last_seen,
433 },
434 lock=False,
435 )
436
437 # Technically an access token might not be associated with
438 # a device so we need to check.
439 if device_id:
440 # this is always an update rather than an upsert: the row should
441 # already exist, and if it doesn't, that may be because it has been
442 # deleted, and we don't want to re-create it.
443 self.db_pool.simple_update_txn(
444 txn,
445 table="devices",
446 keyvalues={"user_id": user_id, "device_id": device_id},
447 updatevalues={
448 "user_agent": user_agent,
449 "last_seen": last_seen,
450 "ip": ip,
451 },
452 )
453 except Exception as e:
454 # Failed to upsert, log and continue
455 logger.error("Failed to insert client IP %r: %r", entry, e)
456
457 async def get_last_client_ip_by_device(
458 self, user_id: str, device_id: Optional[str]
459 ) -> Dict[Tuple[str, str], dict]:
460 """For each device_id listed, give the user_ip it was last seen on
461
462 Args:
463 user_id: The user to fetch devices for.
464 device_id: If None fetches all devices for the user
465
466 Returns:
467 A dictionary mapping a tuple of (user_id, device_id) to dicts, with
468 keys giving the column names from the devices table.
469 """
470
471 keyvalues = {"user_id": user_id}
472 if device_id is not None:
473 keyvalues["device_id"] = device_id
474
475 res = await self.db_pool.simple_select_list(
476 table="devices",
477 keyvalues=keyvalues,
478 retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
479 )
480
481 ret = {(d["user_id"], d["device_id"]): d for d in res}
482 for key in self._batch_row_update:
483 uid, access_token, ip = key
484 if uid == user_id:
485 user_agent, did, last_seen = self._batch_row_update[key]
486 if not device_id or did == device_id:
487 ret[(user_id, device_id)] = {
488 "user_id": user_id,
489 "access_token": access_token,
490 "ip": ip,
491 "user_agent": user_agent,
492 "device_id": did,
493 "last_seen": last_seen,
494 }
495 return ret
496
497 async def get_user_ip_and_agents(self, user):
498 user_id = user.to_string()
499 results = {}
500
501 for key in self._batch_row_update:
502 uid, access_token, ip, = key
503 if uid == user_id:
504 user_agent, _, last_seen = self._batch_row_update[key]
505 results[(access_token, ip)] = (user_agent, last_seen)
506
507 rows = await self.db_pool.simple_select_list(
508 table="user_ips",
509 keyvalues={"user_id": user_id},
510 retcols=["access_token", "ip", "user_agent", "last_seen"],
511 desc="get_user_ip_and_agents",
512 )
513
514 results.update(
515 ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
516 for row in rows
517 )
518 return [
519 {
520 "access_token": access_token,
521 "ip": ip,
522 "user_agent": user_agent,
523 "last_seen": last_seen,
524 }
525 for (access_token, ip), (user_agent, last_seen) in results.items()
526 ]
527361
528362 @wrap_as_background_process("prune_old_user_ips")
529363 async def _prune_old_user_ips(self):
570404 await self.db_pool.runInteraction(
571405 "_prune_old_user_ips", _prune_old_user_ips_txn
572406 )
407
408
409 class ClientIpStore(ClientIpWorkerStore):
410 def __init__(self, database: DatabasePool, db_conn, hs):
411
412 self.client_ip_last_seen = LruCache(
413 cache_name="client_ip_last_seen", keylen=4, max_size=50000
414 )
415
416 super().__init__(database, db_conn, hs)
417
418 # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
419 self._batch_row_update = {}
420
421 self._client_ip_looper = self._clock.looping_call(
422 self._update_client_ips_batch, 5 * 1000
423 )
424 self.hs.get_reactor().addSystemEventTrigger(
425 "before", "shutdown", self._update_client_ips_batch
426 )
427
428 async def insert_client_ip(
429 self, user_id, access_token, ip, user_agent, device_id, now=None
430 ):
431 if not now:
432 now = int(self._clock.time_msec())
433 key = (user_id, access_token, ip)
434
435 try:
436 last_seen = self.client_ip_last_seen.get(key)
437 except KeyError:
438 last_seen = None
439 await self.populate_monthly_active_users(user_id)
440 # Rate-limited inserts
441 if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
442 return
443
444 self.client_ip_last_seen.set(key, now)
445
446 self._batch_row_update[key] = (user_agent, device_id, now)
447
448 @wrap_as_background_process("update_client_ips")
449 async def _update_client_ips_batch(self) -> None:
450
451 # If the DB pool has already terminated, don't try updating
452 if not self.db_pool.is_running():
453 return
454
455 to_update = self._batch_row_update
456 self._batch_row_update = {}
457
458 await self.db_pool.runInteraction(
459 "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
460 )
461
462 def _update_client_ips_batch_txn(self, txn, to_update):
463 if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
464 not self.database_engine.can_native_upsert
465 ):
466 self.database_engine.lock_table(txn, "user_ips")
467
468 for entry in to_update.items():
469 (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
470
471 try:
472 self.db_pool.simple_upsert_txn(
473 txn,
474 table="user_ips",
475 keyvalues={
476 "user_id": user_id,
477 "access_token": access_token,
478 "ip": ip,
479 },
480 values={
481 "user_agent": user_agent,
482 "device_id": device_id,
483 "last_seen": last_seen,
484 },
485 lock=False,
486 )
487
488 # Technically an access token might not be associated with
489 # a device so we need to check.
490 if device_id:
491 # this is always an update rather than an upsert: the row should
492 # already exist, and if it doesn't, that may be because it has been
493 # deleted, and we don't want to re-create it.
494 self.db_pool.simple_update_txn(
495 txn,
496 table="devices",
497 keyvalues={"user_id": user_id, "device_id": device_id},
498 updatevalues={
499 "user_agent": user_agent,
500 "last_seen": last_seen,
501 "ip": ip,
502 },
503 )
504 except Exception as e:
505 # Failed to upsert, log and continue
506 logger.error("Failed to insert client IP %r: %r", entry, e)
507
508 async def get_last_client_ip_by_device(
509 self, user_id: str, device_id: Optional[str]
510 ) -> Dict[Tuple[str, str], dict]:
511 """For each device_id listed, give the user_ip it was last seen on
512
513 Args:
514 user_id: The user to fetch devices for.
515 device_id: If None fetches all devices for the user
516
517 Returns:
518 A dictionary mapping a tuple of (user_id, device_id) to dicts, with
519 keys giving the column names from the devices table.
520 """
521
522 keyvalues = {"user_id": user_id}
523 if device_id is not None:
524 keyvalues["device_id"] = device_id
525
526 res = await self.db_pool.simple_select_list(
527 table="devices",
528 keyvalues=keyvalues,
529 retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
530 )
531
532 ret = {(d["user_id"], d["device_id"]): d for d in res}
533 for key in self._batch_row_update:
534 uid, access_token, ip = key
535 if uid == user_id:
536 user_agent, did, last_seen = self._batch_row_update[key]
537 if not device_id or did == device_id:
538 ret[(user_id, device_id)] = {
539 "user_id": user_id,
540 "access_token": access_token,
541 "ip": ip,
542 "user_agent": user_agent,
543 "device_id": did,
544 "last_seen": last_seen,
545 }
546 return ret
547
548 async def get_user_ip_and_agents(self, user):
549 user_id = user.to_string()
550 results = {}
551
552 for key in self._batch_row_update:
553 uid, access_token, ip, = key
554 if uid == user_id:
555 user_agent, _, last_seen = self._batch_row_update[key]
556 results[(access_token, ip)] = (user_agent, last_seen)
557
558 rows = await self.db_pool.simple_select_list(
559 table="user_ips",
560 keyvalues={"user_id": user_id},
561 retcols=["access_token", "ip", "user_agent", "last_seen"],
562 desc="get_user_ip_and_agents",
563 )
564
565 results.update(
566 ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
567 for row in rows
568 )
569 return [
570 {
571 "access_token": access_token,
572 "ip": ip,
573 "user_agent": user_agent,
574 "last_seen": last_seen,
575 }
576 for (access_token, ip), (user_agent, last_seen) in results.items()
577 ]
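The split above leaves the pruning loop in ClientIpWorkerStore and moves batching and caching into ClientIpStore, swapping the old Cache for LruCache with explicit get/set. A hedged, self-contained sketch of that cache usage; the cache name, key, and timestamp are invented:

from synapse.util.caches.lrucache import LruCache

last_seen_cache = LruCache(cache_name="example_last_seen", max_size=1000)

key = ("@alice:example.org", "ACCESS_TOKEN", "10.0.0.1")
if last_seen_cache.get(key) is None:
    # Record the last-seen timestamp (milliseconds) for this (user, token, ip).
    last_seen_cache.set(key, 1_600_000_000_000)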
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
22 # Copyright 2019 New Vector Ltd
3 # Copyright 2019 The Matrix.org Foundation C.I.C.
3 # Copyright 2019,2020 The Matrix.org Foundation C.I.C.
44 #
55 # Licensed under the Apache License, Version 2.0 (the "License");
66 # you may not use this file except in compliance with the License.
2424 trace,
2525 whitelisted_homeserver,
2626 )
27 from synapse.metrics.background_process_metrics import run_as_background_process
27 from synapse.metrics.background_process_metrics import wrap_as_background_process
2828 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
2929 from synapse.storage.database import (
3030 DatabasePool,
3232 make_tuple_comparison_clause,
3333 )
3434 from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
35 from synapse.util import json_encoder
36 from synapse.util.caches.descriptors import Cache, cached, cachedList
35 from synapse.util import json_decoder, json_encoder
36 from synapse.util.caches.descriptors import cached, cachedList
37 from synapse.util.caches.lrucache import LruCache
3738 from synapse.util.iterutils import batch_iter
3839 from synapse.util.stringutils import shortstr
3940
4748
4849
4950 class DeviceWorkerStore(SQLBaseStore):
51 def __init__(self, database: DatabasePool, db_conn, hs):
52 super().__init__(database, db_conn, hs)
53
54 if hs.config.run_background_tasks:
55 self._clock.looping_call(
56 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
57 )
58
5059 async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
5160 """Retrieve a device. Only returns devices that are not marked as
5261 hidden.
697706 _mark_remote_user_device_list_as_unsubscribed_txn,
698707 )
699708
700
701 class DeviceBackgroundUpdateStore(SQLBaseStore):
702 def __init__(self, database: DatabasePool, db_conn, hs):
703 super().__init__(database, db_conn, hs)
704
705 self.db_pool.updates.register_background_index_update(
706 "device_lists_stream_idx",
707 index_name="device_lists_stream_user_id",
708 table="device_lists_stream",
709 columns=["user_id", "device_id"],
710 )
711
712 # create a unique index on device_lists_remote_cache
713 self.db_pool.updates.register_background_index_update(
714 "device_lists_remote_cache_unique_idx",
715 index_name="device_lists_remote_cache_unique_id",
716 table="device_lists_remote_cache",
717 columns=["user_id", "device_id"],
718 unique=True,
719 )
720
721 # And one on device_lists_remote_extremeties
722 self.db_pool.updates.register_background_index_update(
723 "device_lists_remote_extremeties_unique_idx",
724 index_name="device_lists_remote_extremeties_unique_idx",
725 table="device_lists_remote_extremeties",
726 columns=["user_id"],
727 unique=True,
728 )
729
730 # once they complete, we can remove the old non-unique indexes.
731 self.db_pool.updates.register_background_update_handler(
732 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
733 self._drop_device_list_streams_non_unique_indexes,
734 )
735
736 # clear out duplicate device list outbound pokes
737 self.db_pool.updates.register_background_update_handler(
738 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
739 )
740
741 # a pair of background updates that were added during the 1.14 release cycle,
742 # but replaced with 58/06dlols_unique_idx.py
743 self.db_pool.updates.register_noop_background_update(
744 "device_lists_outbound_last_success_unique_idx",
745 )
746 self.db_pool.updates.register_noop_background_update(
747 "drop_device_lists_outbound_last_success_non_unique_idx",
748 )
749
750 async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
751 def f(conn):
752 txn = conn.cursor()
753 txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
754 txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
755 txn.close()
756
757 await self.db_pool.runWithConnection(f)
758 await self.db_pool.updates._end_background_update(
759 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
760 )
761 return 1
762
763 async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
764 # for some reason, we have accumulated duplicate entries in
765 # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
766 # efficient.
767 #
768 # For each duplicate, we delete all the existing rows and put one back.
769
770 KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
771 last_row = progress.get(
772 "last_row",
773 {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
774 )
775
776 def _txn(txn):
777 clause, args = make_tuple_comparison_clause(
778 self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
779 )
780 sql = """
781 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
782 FROM device_lists_outbound_pokes
783 WHERE %s
784 GROUP BY %s
785 HAVING count(*) > 1
786 ORDER BY %s
787 LIMIT ?
788 """ % (
789 clause, # WHERE
790 ",".join(KEY_COLS), # GROUP BY
791 ",".join(KEY_COLS), # ORDER BY
792 )
793 txn.execute(sql, args + [batch_size])
794 rows = self.db_pool.cursor_to_dict(txn)
795
796 row = None
797 for row in rows:
798 self.db_pool.simple_delete_txn(
799 txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
800 )
801
802 row["sent"] = False
803 self.db_pool.simple_insert_txn(
804 txn, "device_lists_outbound_pokes", row,
805 )
806
807 if row:
808 self.db_pool.updates._background_update_progress_txn(
809 txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
810 )
811
812 return len(rows)
813
814 rows = await self.db_pool.runInteraction(
815 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
816 )
817
818 if not rows:
819 await self.db_pool.updates._end_background_update(
820 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
821 )
822
823 return rows
824
825
826 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
827 def __init__(self, database: DatabasePool, db_conn, hs):
828 super().__init__(database, db_conn, hs)
829
830 # Map of (user_id, device_id) -> bool. If there is an entry, that implies
831 # the device exists.
832 self.device_id_exists_cache = Cache(
833 name="device_id_exists", keylen=2, max_entries=10000
834 )
835
836 self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
837
838 async def store_device(
839 self, user_id: str, device_id: str, initial_device_display_name: str
840 ) -> bool:
841 """Ensure the given device is known; add it to the store if not
842
843 Args:
844 user_id: id of user associated with the device
845 device_id: id of device
846 initial_device_display_name: initial displayname of the device.
847 Ignored if device exists.
848
709 async def get_dehydrated_device(
710 self, user_id: str
711 ) -> Optional[Tuple[str, JsonDict]]:
712 """Retrieve the information for a dehydrated device.
713
714 Args:
715 user_id: the user whose dehydrated device we are looking for
849716 Returns:
850 Whether the device was inserted; False if a device with that ID already existed.
851
852 Raises:
853 StoreError: if the device is already in use
854 """
855 key = (user_id, device_id)
856 if self.device_id_exists_cache.get(key, None):
857 return False
858
859 try:
860 inserted = await self.db_pool.simple_insert(
861 "devices",
862 values={
863 "user_id": user_id,
864 "device_id": device_id,
865 "display_name": initial_device_display_name,
866 "hidden": False,
867 },
868 desc="store_device",
869 or_ignore=True,
870 )
871 if not inserted:
872 # if the device already exists, check if it's a real device, or
873 # if the device ID is reserved by something else
874 hidden = await self.db_pool.simple_select_one_onecol(
875 "devices",
876 keyvalues={"user_id": user_id, "device_id": device_id},
877 retcol="hidden",
878 )
879 if hidden:
880 raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
881 self.device_id_exists_cache.prefill(key, True)
882 return inserted
883 except StoreError:
884 raise
885 except Exception as e:
886 logger.error(
887 "store_device with device_id=%s(%r) user_id=%s(%r)"
888 " display_name=%s(%r) failed: %s",
889 type(device_id).__name__,
890 device_id,
891 type(user_id).__name__,
892 user_id,
893 type(initial_device_display_name).__name__,
894 initial_device_display_name,
895 e,
896 )
897 raise StoreError(500, "Problem storing device.")
898
899 async def delete_device(self, user_id: str, device_id: str) -> None:
900 """Delete a device.
901
902 Args:
903 user_id: The ID of the user which owns the device
904 device_id: The ID of the device to delete
905 """
906 await self.db_pool.simple_delete_one(
907 table="devices",
908 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
909 desc="delete_device",
910 )
911
912 self.device_id_exists_cache.invalidate((user_id, device_id))
913
914 async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
915 """Deletes several devices.
916
917 Args:
918 user_id: The ID of the user which owns the devices
919 device_ids: The IDs of the devices to delete
920 """
921 await self.db_pool.simple_delete_many(
922 table="devices",
923 column="device_id",
924 iterable=device_ids,
925 keyvalues={"user_id": user_id, "hidden": False},
926 desc="delete_devices",
927 )
928 for device_id in device_ids:
929 self.device_id_exists_cache.invalidate((user_id, device_id))
930
931 async def update_device(
932 self, user_id: str, device_id: str, new_display_name: Optional[str] = None
933 ) -> None:
934 """Update a device. Only updates the device if it is not marked as
935 hidden.
936
937 Args:
938 user_id: The ID of the user which owns the device
939 device_id: The ID of the device to update
940 new_display_name: new displayname for device; None to leave unchanged
941 Raises:
942 StoreError: if the device is not found
943 """
944 updates = {}
945 if new_display_name is not None:
946 updates["display_name"] = new_display_name
947 if not updates:
948 return None
949 await self.db_pool.simple_update_one(
950 table="devices",
951 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
952 updatevalues=updates,
953 desc="update_device",
954 )
955
956 async def update_remote_device_list_cache_entry(
957 self, user_id: str, device_id: str, content: JsonDict, stream_id: int
958 ) -> None:
959 """Updates a single device in the cache of a remote user's devicelist.
960
961 Note: assumes that we are the only thread that can be updating this user's
962 device list.
963
964 Args:
965 user_id: User to update device list for
966 device_id: ID of the device being updated
967 content: new data on this device
968 stream_id: the version of the device list
969 """
970 await self.db_pool.runInteraction(
971 "update_remote_device_list_cache_entry",
972 self._update_remote_device_list_cache_entry_txn,
717 a tuple whose first item is the device ID, and the second item is
718 the dehydrated device information
719 """
720 # FIXME: make sure device ID still exists in devices table
721 row = await self.db_pool.simple_select_one(
722 table="dehydrated_devices",
723 keyvalues={"user_id": user_id},
724 retcols=["device_id", "device_data"],
725 allow_none=True,
726 )
727 return (
728 (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
729 )
730
731 def _store_dehydrated_device_txn(
732 self, txn, user_id: str, device_id: str, device_data: str
733 ) -> Optional[str]:
734 old_device_id = self.db_pool.simple_select_one_onecol_txn(
735 txn,
736 table="dehydrated_devices",
737 keyvalues={"user_id": user_id},
738 retcol="device_id",
739 allow_none=True,
740 )
741 self.db_pool.simple_upsert_txn(
742 txn,
743 table="dehydrated_devices",
744 keyvalues={"user_id": user_id},
745 values={"device_id": device_id, "device_data": device_data},
746 )
747 return old_device_id
748
749 async def store_dehydrated_device(
750 self, user_id: str, device_id: str, device_data: JsonDict
751 ) -> Optional[str]:
752 """Store a dehydrated device for a user.
753
754 Args:
755 user_id: the user that we are storing the device for
756 device_id: the ID of the dehydrated device
757 device_data: the dehydrated device information
758 Returns:
759 device id of the user's previous dehydrated device, if any
760 """
761 return await self.db_pool.runInteraction(
762 "store_dehydrated_device_txn",
763 self._store_dehydrated_device_txn,
973764 user_id,
974765 device_id,
975 content,
976 stream_id,
977 )
978
979 def _update_remote_device_list_cache_entry_txn(
980 self,
981 txn: LoggingTransaction,
982 user_id: str,
983 device_id: str,
984 content: JsonDict,
985 stream_id: int,
766 json_encoder.encode(device_data),
767 )
768
769 async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
770 """Remove a dehydrated device.
771
772 Args:
773 user_id: the user that the dehydrated device belongs to
774 device_id: the ID of the dehydrated device
775 """
776 count = await self.db_pool.simple_delete(
777 "dehydrated_devices",
778 {"user_id": user_id, "device_id": device_id},
779 desc="remove_dehydrated_device",
780 )
781 return count >= 1
782
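
Dehydrated devices are stored as at most one row per user, keyed on user_id: storing a new one replaces the previous row and returns the old device ID so the caller can tidy it up. Below is a minimal standalone sketch of that single-slot replacement, using sqlite3 in place of Synapse's DatabasePool; the table and column names mirror the diff, everything else (connection handling, example IDs) is illustrative.

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE dehydrated_devices "
    "(user_id TEXT PRIMARY KEY, device_id TEXT, device_data TEXT)"
)

def store_dehydrated_device(user_id, device_id, device_data):
    """Replace the user's single dehydrated device, returning the previous device_id, if any."""
    row = conn.execute(
        "SELECT device_id FROM dehydrated_devices WHERE user_id = ?", (user_id,)
    ).fetchone()
    old_device_id = row[0] if row else None
    # user_id is the primary key, so OR REPLACE keeps exactly one row per user.
    conn.execute(
        "INSERT OR REPLACE INTO dehydrated_devices (user_id, device_id, device_data) "
        "VALUES (?, ?, ?)",
        (user_id, device_id, json.dumps(device_data)),
    )
    return old_device_id

assert store_dehydrated_device("@alice:example.org", "DEV1", {"algorithms": []}) is None
assert store_dehydrated_device("@alice:example.org", "DEV2", {"algorithms": []}) == "DEV1"
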
783 @wrap_as_background_process("prune_old_outbound_device_pokes")
784 async def _prune_old_outbound_device_pokes(
785 self, prune_age: int = 24 * 60 * 60 * 1000
986786 ) -> None:
987 if content.get("deleted"):
988 self.db_pool.simple_delete_txn(
989 txn,
990 table="device_lists_remote_cache",
991 keyvalues={"user_id": user_id, "device_id": device_id},
992 )
993
994 txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
995 else:
996 self.db_pool.simple_upsert_txn(
997 txn,
998 table="device_lists_remote_cache",
999 keyvalues={"user_id": user_id, "device_id": device_id},
1000 values={"content": json_encoder.encode(content)},
1001 # we don't need to lock, because we assume we are the only thread
1002 # updating this user's devices.
1003 lock=False,
1004 )
1005
1006 txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
1007 txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
1008 txn.call_after(
1009 self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
1010 )
1011
1012 self.db_pool.simple_upsert_txn(
1013 txn,
1014 table="device_lists_remote_extremeties",
1015 keyvalues={"user_id": user_id},
1016 values={"stream_id": stream_id},
1017 # again, we can assume we are the only thread updating this user's
1018 # extremity.
1019 lock=False,
1020 )
1021
1022 async def update_remote_device_list_cache(
1023 self, user_id: str, devices: List[dict], stream_id: int
1024 ) -> None:
1025 """Replace the entire cache of the remote user's devices.
1026
1027 Note: assumes that we are the only thread that can be updating this user's
1028 device list.
1029
1030 Args:
1031 user_id: User to update device list for
1032 devices: list of device objects supplied over federation
1033 stream_id: the version of the device list
1034 """
1035 await self.db_pool.runInteraction(
1036 "update_remote_device_list_cache",
1037 self._update_remote_device_list_cache_txn,
1038 user_id,
1039 devices,
1040 stream_id,
1041 )
1042
1043 def _update_remote_device_list_cache_txn(
1044 self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
1045 ) -> None:
1046 self.db_pool.simple_delete_txn(
1047 txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
1048 )
1049
1050 self.db_pool.simple_insert_many_txn(
1051 txn,
1052 table="device_lists_remote_cache",
1053 values=[
1054 {
1055 "user_id": user_id,
1056 "device_id": content["device_id"],
1057 "content": json_encoder.encode(content),
1058 }
1059 for content in devices
1060 ],
1061 )
1062
1063 txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
1064 txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
1065 txn.call_after(
1066 self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
1067 )
1068
1069 self.db_pool.simple_upsert_txn(
1070 txn,
1071 table="device_lists_remote_extremeties",
1072 keyvalues={"user_id": user_id},
1073 values={"stream_id": stream_id},
1074 # we don't need to lock, because we can assume we are the only thread
1075 # updating this user's extremity.
1076 lock=False,
1077 )
1078
1079 # If we're replacing the remote user's device list cache, presumably
1080 # we've done a full resync, so we remove the entry that says we need
1081 # to resync
1082 self.db_pool.simple_delete_txn(
1083 txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
1084 )
1085
1086 async def add_device_change_to_streams(
1087 self, user_id: str, device_ids: Collection[str], hosts: List[str]
1088 ):
1089 """Persist that a user's devices have been updated, and which hosts
1090 (if any) should be poked.
1091 """
1092 if not device_ids:
1093 return
1094
1095 async with self._device_list_id_gen.get_next_mult(
1096 len(device_ids)
1097 ) as stream_ids:
1098 await self.db_pool.runInteraction(
1099 "add_device_change_to_stream",
1100 self._add_device_change_to_stream_txn,
1101 user_id,
1102 device_ids,
1103 stream_ids,
1104 )
1105
1106 if not hosts:
1107 return stream_ids[-1]
1108
1109 context = get_active_span_text_map()
1110 async with self._device_list_id_gen.get_next_mult(
1111 len(hosts) * len(device_ids)
1112 ) as stream_ids:
1113 await self.db_pool.runInteraction(
1114 "add_device_outbound_poke_to_stream",
1115 self._add_device_outbound_poke_to_stream_txn,
1116 user_id,
1117 device_ids,
1118 hosts,
1119 stream_ids,
1120 context,
1121 )
1122
1123 return stream_ids[-1]
1124
1125 def _add_device_change_to_stream_txn(
1126 self,
1127 txn: LoggingTransaction,
1128 user_id: str,
1129 device_ids: Collection[str],
1130 stream_ids: List[str],
1131 ):
1132 txn.call_after(
1133 self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
1134 )
1135
1136 min_stream_id = stream_ids[0]
1137
1138 # Delete older entries in the table, as we really only care about
1139 # when the latest change happened.
1140 txn.executemany(
1141 """
1142 DELETE FROM device_lists_stream
1143 WHERE user_id = ? AND device_id = ? AND stream_id < ?
1144 """,
1145 [(user_id, device_id, min_stream_id) for device_id in device_ids],
1146 )
1147
1148 self.db_pool.simple_insert_many_txn(
1149 txn,
1150 table="device_lists_stream",
1151 values=[
1152 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
1153 for stream_id, device_id in zip(stream_ids, device_ids)
1154 ],
1155 )
1156
1157 def _add_device_outbound_poke_to_stream_txn(
1158 self,
1159 txn: LoggingTransaction,
1160 user_id: str,
1161 device_ids: Collection[str],
1162 hosts: List[str],
1163 stream_ids: List[str],
1164 context: Dict[str, str],
1165 ):
1166 for host in hosts:
1167 txn.call_after(
1168 self._device_list_federation_stream_cache.entity_has_changed,
1169 host,
1170 stream_ids[-1],
1171 )
1172
1173 now = self._clock.time_msec()
1174 next_stream_id = iter(stream_ids)
1175
1176 self.db_pool.simple_insert_many_txn(
1177 txn,
1178 table="device_lists_outbound_pokes",
1179 values=[
1180 {
1181 "destination": destination,
1182 "stream_id": next(next_stream_id),
1183 "user_id": user_id,
1184 "device_id": device_id,
1185 "sent": False,
1186 "ts": now,
1187 "opentracing_context": json_encoder.encode(context)
1188 if whitelisted_homeserver(destination)
1189 else "{}",
1190 }
1191 for destination in hosts
1192 for device_id in device_ids
1193 ],
1194 )
1195
1196 def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
1197787 """Delete old entries out of the device_lists_outbound_pokes to ensure
1198788 that we don't fill up due to dead servers.
1199789
1278868
1279869 logger.info("Pruned %d device list outbound pokes", count)
1280870
1281 return run_as_background_process(
1282 "prune_old_outbound_device_pokes",
1283 self.db_pool.runInteraction,
1284 "_prune_old_outbound_device_pokes",
1285 _prune_txn,
1286 )
871 await self.db_pool.runInteraction(
872 "_prune_old_outbound_device_pokes", _prune_txn,
873 )
874
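
Several of the hunks in this file (and in event_federation.py and event_push_actions.py further down) follow the same pattern: a method that used to wrap itself in run_as_background_process becomes an async method decorated with wrap_as_background_process, and its looping call is only registered when hs.config.run_background_tasks is set. The following is a rough, self-contained sketch of that decorator idea, not Synapse's implementation (which also records metrics and logging context for the named process):

import asyncio
import functools

def wrap_as_background_process(desc):
    """Decorator: run the wrapped coroutine as a fire-and-forget background task."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The real helper would also account the task under `desc` in metrics.
            return asyncio.ensure_future(func(*args, **kwargs))
        return wrapper
    return decorator

@wrap_as_background_process("prune_old_outbound_device_pokes")
async def prune_old_outbound_device_pokes(prune_age_ms: int = 24 * 60 * 60 * 1000) -> None:
    print("pruning pokes older than %d ms" % prune_age_ms)

async def main():
    task = prune_old_outbound_device_pokes()  # returns a Task immediately
    await task

asyncio.run(main())
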
875
876 class DeviceBackgroundUpdateStore(SQLBaseStore):
877 def __init__(self, database: DatabasePool, db_conn, hs):
878 super().__init__(database, db_conn, hs)
879
880 self.db_pool.updates.register_background_index_update(
881 "device_lists_stream_idx",
882 index_name="device_lists_stream_user_id",
883 table="device_lists_stream",
884 columns=["user_id", "device_id"],
885 )
886
887 # create a unique index on device_lists_remote_cache
888 self.db_pool.updates.register_background_index_update(
889 "device_lists_remote_cache_unique_idx",
890 index_name="device_lists_remote_cache_unique_id",
891 table="device_lists_remote_cache",
892 columns=["user_id", "device_id"],
893 unique=True,
894 )
895
896 # And one on device_lists_remote_extremeties
897 self.db_pool.updates.register_background_index_update(
898 "device_lists_remote_extremeties_unique_idx",
899 index_name="device_lists_remote_extremeties_unique_idx",
900 table="device_lists_remote_extremeties",
901 columns=["user_id"],
902 unique=True,
903 )
904
905 # once they complete, we can remove the old non-unique indexes.
906 self.db_pool.updates.register_background_update_handler(
907 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
908 self._drop_device_list_streams_non_unique_indexes,
909 )
910
911 # clear out duplicate device list outbound pokes
912 self.db_pool.updates.register_background_update_handler(
913 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
914 )
915
916 # a pair of background updates that were added during the 1.14 release cycle,
917 # but replaced with 58/06dlols_unique_idx.py
918 self.db_pool.updates.register_noop_background_update(
919 "device_lists_outbound_last_success_unique_idx",
920 )
921 self.db_pool.updates.register_noop_background_update(
922 "drop_device_lists_outbound_last_success_non_unique_idx",
923 )
924
925 async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
926 def f(conn):
927 txn = conn.cursor()
928 txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
929 txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
930 txn.close()
931
932 await self.db_pool.runWithConnection(f)
933 await self.db_pool.updates._end_background_update(
934 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
935 )
936 return 1
937
938 async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
939 # for some reason, we have accumulated duplicate entries in
940 # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
941 # efficient.
942 #
943 # For each duplicate, we delete all the existing rows and put one back.
944
945 KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
946 last_row = progress.get(
947 "last_row",
948 {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
949 )
950
951 def _txn(txn):
952 clause, args = make_tuple_comparison_clause(
953 self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
954 )
955 sql = """
956 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
957 FROM device_lists_outbound_pokes
958 WHERE %s
959 GROUP BY %s
960 HAVING count(*) > 1
961 ORDER BY %s
962 LIMIT ?
963 """ % (
964 clause, # WHERE
965 ",".join(KEY_COLS), # GROUP BY
966 ",".join(KEY_COLS), # ORDER BY
967 )
968 txn.execute(sql, args + [batch_size])
969 rows = self.db_pool.cursor_to_dict(txn)
970
971 row = None
972 for row in rows:
973 self.db_pool.simple_delete_txn(
974 txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
975 )
976
977 row["sent"] = False
978 self.db_pool.simple_insert_txn(
979 txn, "device_lists_outbound_pokes", row,
980 )
981
982 if row:
983 self.db_pool.updates._background_update_progress_txn(
984 txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
985 )
986
987 return len(rows)
988
989 rows = await self.db_pool.runInteraction(
990 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
991 )
992
993 if not rows:
994 await self.db_pool.updates._end_background_update(
995 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
996 )
997
998 return rows
999
1000
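
As the comments in _remove_duplicate_outbound_pokes explain, the background update scans device_lists_outbound_pokes in key order and, for every duplicated (stream_id, destination, user_id, device_id) key, deletes all of the rows and re-inserts a single one with sent reset to False. Here is a pure-Python sketch of that "delete them all, put one back" step; the row shape follows the table columns in the diff, the data itself is made up:

from collections import defaultdict

KEY_COLS = ("stream_id", "destination", "user_id", "device_id")

rows = [
    {"stream_id": 1, "destination": "b.org", "user_id": "@a:x", "device_id": "D1", "sent": True, "ts": 10},
    {"stream_id": 1, "destination": "b.org", "user_id": "@a:x", "device_id": "D1", "sent": False, "ts": 20},
    {"stream_id": 2, "destination": "c.org", "user_id": "@a:x", "device_id": "D1", "sent": False, "ts": 30},
]

def dedupe_outbound_pokes(rows):
    """Collapse duplicates to one row per key, keeping the newest ts and resetting sent."""
    by_key = defaultdict(list)
    for row in rows:
        by_key[tuple(row[c] for c in KEY_COLS)].append(row)

    deduped = []
    for dupes in by_key.values():
        keep = dict(max(dupes, key=lambda r: r["ts"]))
        if len(dupes) > 1:
            keep["sent"] = False  # re-inserted rows are marked unsent, as in the diff
        deduped.append(keep)
    return deduped

print(dedupe_outbound_pokes(rows))  # the duplicated key collapses to a single row
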
1001 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
1002 def __init__(self, database: DatabasePool, db_conn, hs):
1003 super().__init__(database, db_conn, hs)
1004
1005 # Map of (user_id, device_id) -> bool. If there is an entry, that implies
1006 # the device exists.
1007 self.device_id_exists_cache = LruCache(
1008 cache_name="device_id_exists", keylen=2, max_size=10000
1009 )
1010
1011 async def store_device(
1012 self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
1013 ) -> bool:
1014 """Ensure the given device is known; add it to the store if not
1015
1016 Args:
1017 user_id: id of user associated with the device
1018 device_id: id of device
1019 initial_device_display_name: initial displayname of the device.
1020 Ignored if device exists.
1021
1022 Returns:
1023 Whether the device was inserted; False if a device with that ID already existed.
1024
1025 Raises:
1026 StoreError: if the device is already in use
1027 """
1028 key = (user_id, device_id)
1029 if self.device_id_exists_cache.get(key, None):
1030 return False
1031
1032 try:
1033 inserted = await self.db_pool.simple_insert(
1034 "devices",
1035 values={
1036 "user_id": user_id,
1037 "device_id": device_id,
1038 "display_name": initial_device_display_name,
1039 "hidden": False,
1040 },
1041 desc="store_device",
1042 or_ignore=True,
1043 )
1044 if not inserted:
1045 # if the device already exists, check if it's a real device, or
1046 # if the device ID is reserved by something else
1047 hidden = await self.db_pool.simple_select_one_onecol(
1048 "devices",
1049 keyvalues={"user_id": user_id, "device_id": device_id},
1050 retcol="hidden",
1051 )
1052 if hidden:
1053 raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
1054 self.device_id_exists_cache.set(key, True)
1055 return inserted
1056 except StoreError:
1057 raise
1058 except Exception as e:
1059 logger.error(
1060 "store_device with device_id=%s(%r) user_id=%s(%r)"
1061 " display_name=%s(%r) failed: %s",
1062 type(device_id).__name__,
1063 device_id,
1064 type(user_id).__name__,
1065 user_id,
1066 type(initial_device_display_name).__name__,
1067 initial_device_display_name,
1068 e,
1069 )
1070 raise StoreError(500, "Problem storing device.")
1071
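
store_device is effectively an insert-or-ignore: if the row already exists, the hidden flag decides whether this is an ordinary existing device (return False) or a device ID reserved for something else (raise). A standalone sqlite3 sketch of that flow, with a plain RuntimeError standing in for StoreError and made-up IDs:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE devices (user_id TEXT, device_id TEXT, display_name TEXT, "
    "hidden BOOLEAN, PRIMARY KEY (user_id, device_id))"
)

def store_device(user_id, device_id, display_name):
    before = conn.total_changes
    conn.execute(
        "INSERT OR IGNORE INTO devices (user_id, device_id, display_name, hidden) "
        "VALUES (?, ?, ?, 0)",
        (user_id, device_id, display_name),
    )
    if conn.total_changes == before:
        # The row already existed: a hidden row means the ID is reserved for something else.
        (hidden,) = conn.execute(
            "SELECT hidden FROM devices WHERE user_id = ? AND device_id = ?",
            (user_id, device_id),
        ).fetchone()
        if hidden:
            raise RuntimeError("The device ID is in use")
        return False
    return True

assert store_device("@alice:example.org", "PHONE", "Alice's phone") is True
assert store_device("@alice:example.org", "PHONE", "renamed") is False
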
1072 async def delete_device(self, user_id: str, device_id: str) -> None:
1073 """Delete a device.
1074
1075 Args:
1076 user_id: The ID of the user which owns the device
1077 device_id: The ID of the device to delete
1078 """
1079 await self.db_pool.simple_delete_one(
1080 table="devices",
1081 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
1082 desc="delete_device",
1083 )
1084
1085 self.device_id_exists_cache.invalidate((user_id, device_id))
1086
1087 async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
1088 """Deletes several devices.
1089
1090 Args:
1091 user_id: The ID of the user which owns the devices
1092 device_ids: The IDs of the devices to delete
1093 """
1094 await self.db_pool.simple_delete_many(
1095 table="devices",
1096 column="device_id",
1097 iterable=device_ids,
1098 keyvalues={"user_id": user_id, "hidden": False},
1099 desc="delete_devices",
1100 )
1101 for device_id in device_ids:
1102 self.device_id_exists_cache.invalidate((user_id, device_id))
1103
1104 async def update_device(
1105 self, user_id: str, device_id: str, new_display_name: Optional[str] = None
1106 ) -> None:
1107 """Update a device. Only updates the device if it is not marked as
1108 hidden.
1109
1110 Args:
1111 user_id: The ID of the user which owns the device
1112 device_id: The ID of the device to update
1113 new_display_name: new displayname for device; None to leave unchanged
1114 Raises:
1115 StoreError: if the device is not found
1116 """
1117 updates = {}
1118 if new_display_name is not None:
1119 updates["display_name"] = new_display_name
1120 if not updates:
1121 return None
1122 await self.db_pool.simple_update_one(
1123 table="devices",
1124 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
1125 updatevalues=updates,
1126 desc="update_device",
1127 )
1128
1129 async def update_remote_device_list_cache_entry(
1130 self, user_id: str, device_id: str, content: JsonDict, stream_id: str
1131 ) -> None:
1132 """Updates a single device in the cache of a remote user's devicelist.
1133
1134 Note: assumes that we are the only thread that can be updating this user's
1135 device list.
1136
1137 Args:
1138 user_id: User to update device list for
1139 device_id: ID of the device being updated
1140 content: new data on this device
1141 stream_id: the version of the device list
1142 """
1143 await self.db_pool.runInteraction(
1144 "update_remote_device_list_cache_entry",
1145 self._update_remote_device_list_cache_entry_txn,
1146 user_id,
1147 device_id,
1148 content,
1149 stream_id,
1150 )
1151
1152 def _update_remote_device_list_cache_entry_txn(
1153 self,
1154 txn: LoggingTransaction,
1155 user_id: str,
1156 device_id: str,
1157 content: JsonDict,
1158 stream_id: str,
1159 ) -> None:
1160 if content.get("deleted"):
1161 self.db_pool.simple_delete_txn(
1162 txn,
1163 table="device_lists_remote_cache",
1164 keyvalues={"user_id": user_id, "device_id": device_id},
1165 )
1166
1167 txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
1168 else:
1169 self.db_pool.simple_upsert_txn(
1170 txn,
1171 table="device_lists_remote_cache",
1172 keyvalues={"user_id": user_id, "device_id": device_id},
1173 values={"content": json_encoder.encode(content)},
1174 # we don't need to lock, because we assume we are the only thread
1175 # updating this user's devices.
1176 lock=False,
1177 )
1178
1179 txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
1180 txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
1181 txn.call_after(
1182 self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
1183 )
1184
1185 self.db_pool.simple_upsert_txn(
1186 txn,
1187 table="device_lists_remote_extremeties",
1188 keyvalues={"user_id": user_id},
1189 values={"stream_id": stream_id},
1190 # again, we can assume we are the only thread updating this user's
1191 # extremity.
1192 lock=False,
1193 )
1194
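
For a single remote device update, the cache entry is either dropped (when the update marks the device as deleted) or upserted with the new content, and the per-user extremity row is advanced to the update's stream_id either way. A dictionary-based sketch of that bookkeeping, with two dicts standing in for device_lists_remote_cache and device_lists_remote_extremeties:

import json

device_cache = {}        # (user_id, device_id) -> serialised device content
cache_extremities = {}   # user_id -> last stream_id applied to the cache

def update_remote_device(user_id, device_id, content, stream_id):
    if content.get("deleted"):
        device_cache.pop((user_id, device_id), None)
    else:
        device_cache[(user_id, device_id)] = json.dumps(content)
    # Record how far through the remote user's device list stream the cache has got.
    cache_extremities[user_id] = stream_id

update_remote_device("@bob:remote.org", "TABLET", {"keys": {}}, stream_id=7)
update_remote_device("@bob:remote.org", "TABLET", {"deleted": True}, stream_id=8)
assert ("@bob:remote.org", "TABLET") not in device_cache
assert cache_extremities["@bob:remote.org"] == 8
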
1195 async def update_remote_device_list_cache(
1196 self, user_id: str, devices: List[dict], stream_id: int
1197 ) -> None:
1198 """Replace the entire cache of the remote user's devices.
1199
1200 Note: assumes that we are the only thread that can be updating this user's
1201 device list.
1202
1203 Args:
1204 user_id: User to update device list for
1205 devices: list of device objects supplied over federation
1206 stream_id: the version of the device list
1207 """
1208 await self.db_pool.runInteraction(
1209 "update_remote_device_list_cache",
1210 self._update_remote_device_list_cache_txn,
1211 user_id,
1212 devices,
1213 stream_id,
1214 )
1215
1216 def _update_remote_device_list_cache_txn(
1217 self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
1218 ) -> None:
1219 self.db_pool.simple_delete_txn(
1220 txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
1221 )
1222
1223 self.db_pool.simple_insert_many_txn(
1224 txn,
1225 table="device_lists_remote_cache",
1226 values=[
1227 {
1228 "user_id": user_id,
1229 "device_id": content["device_id"],
1230 "content": json_encoder.encode(content),
1231 }
1232 for content in devices
1233 ],
1234 )
1235
1236 txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
1237 txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
1238 txn.call_after(
1239 self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
1240 )
1241
1242 self.db_pool.simple_upsert_txn(
1243 txn,
1244 table="device_lists_remote_extremeties",
1245 keyvalues={"user_id": user_id},
1246 values={"stream_id": stream_id},
1247 # we don't need to lock, because we can assume we are the only thread
1248 # updating this user's extremity.
1249 lock=False,
1250 )
1251
1252 # If we're replacing the remote user's device list cache, presumably
1253 # we've done a full resync, so we remove the entry that says we need
1254 # to resync
1255 self.db_pool.simple_delete_txn(
1256 txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
1257 )
1258
1259 async def add_device_change_to_streams(
1260 self, user_id: str, device_ids: Collection[str], hosts: List[str]
1261 ):
1262 """Persist that a user's devices have been updated, and which hosts
1263 (if any) should be poked.
1264 """
1265 if not device_ids:
1266 return
1267
1268 async with self._device_list_id_gen.get_next_mult(
1269 len(device_ids)
1270 ) as stream_ids:
1271 await self.db_pool.runInteraction(
1272 "add_device_change_to_stream",
1273 self._add_device_change_to_stream_txn,
1274 user_id,
1275 device_ids,
1276 stream_ids,
1277 )
1278
1279 if not hosts:
1280 return stream_ids[-1]
1281
1282 context = get_active_span_text_map()
1283 async with self._device_list_id_gen.get_next_mult(
1284 len(hosts) * len(device_ids)
1285 ) as stream_ids:
1286 await self.db_pool.runInteraction(
1287 "add_device_outbound_poke_to_stream",
1288 self._add_device_outbound_poke_to_stream_txn,
1289 user_id,
1290 device_ids,
1291 hosts,
1292 stream_ids,
1293 context,
1294 )
1295
1296 return stream_ids[-1]
1297
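
add_device_change_to_streams allocates one stream ID per changed device for the local device_lists_stream rows, then a second block of len(hosts) * len(device_ids) IDs for the per-destination outbound pokes. A small sketch of that fan-out, with a plain counter standing in for the multi-writer stream ID generator:

import itertools

_next_id = itertools.count(1)

def get_next_mult(n):
    """Stand-in for the stream ID generator: hand out n consecutive IDs."""
    return [next(_next_id) for _ in range(n)]

def add_device_change(user_id, device_ids, hosts):
    stream_rows = [
        {"stream_id": sid, "user_id": user_id, "device_id": did}
        for sid, did in zip(get_next_mult(len(device_ids)), device_ids)
    ]
    poke_ids = iter(get_next_mult(len(hosts) * len(device_ids)))
    pokes = [
        {"stream_id": next(poke_ids), "destination": host, "user_id": user_id,
         "device_id": did, "sent": False}
        for host in hosts
        for did in device_ids
    ]
    return stream_rows, pokes

rows, pokes = add_device_change("@alice:example.org", ["D1", "D2"], ["b.org", "c.org"])
print(len(rows), len(pokes))  # 2 local stream rows, 4 outbound pokes
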
1298 def _add_device_change_to_stream_txn(
1299 self,
1300 txn: LoggingTransaction,
1301 user_id: str,
1302 device_ids: Collection[str],
1303 stream_ids: List[str],
1304 ):
1305 txn.call_after(
1306 self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
1307 )
1308
1309 min_stream_id = stream_ids[0]
1310
1311 # Delete older entries in the table, as we really only care about
1312 # when the latest change happened.
1313 txn.executemany(
1314 """
1315 DELETE FROM device_lists_stream
1316 WHERE user_id = ? AND device_id = ? AND stream_id < ?
1317 """,
1318 [(user_id, device_id, min_stream_id) for device_id in device_ids],
1319 )
1320
1321 self.db_pool.simple_insert_many_txn(
1322 txn,
1323 table="device_lists_stream",
1324 values=[
1325 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
1326 for stream_id, device_id in zip(stream_ids, device_ids)
1327 ],
1328 )
1329
1330 def _add_device_outbound_poke_to_stream_txn(
1331 self,
1332 txn: LoggingTransaction,
1333 user_id: str,
1334 device_ids: Collection[str],
1335 hosts: List[str],
1336 stream_ids: List[str],
1337 context: Dict[str, str],
1338 ):
1339 for host in hosts:
1340 txn.call_after(
1341 self._device_list_federation_stream_cache.entity_has_changed,
1342 host,
1343 stream_ids[-1],
1344 )
1345
1346 now = self._clock.time_msec()
1347 next_stream_id = iter(stream_ids)
1348
1349 self.db_pool.simple_insert_many_txn(
1350 txn,
1351 table="device_lists_outbound_pokes",
1352 values=[
1353 {
1354 "destination": destination,
1355 "stream_id": next(next_stream_id),
1356 "user_id": user_id,
1357 "device_id": device_id,
1358 "sent": False,
1359 "ts": now,
1360 "opentracing_context": json_encoder.encode(context)
1361 if whitelisted_homeserver(destination)
1362 else "{}",
1363 }
1364 for destination in hosts
1365 for device_id in device_ids
1366 ],
1367 )
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
22 # Copyright 2019 New Vector Ltd
3 # Copyright 2019 The Matrix.org Foundation C.I.C.
3 # Copyright 2019,2020 The Matrix.org Foundation C.I.C.
44 #
55 # Licensed under the Apache License, Version 2.0 (the "License");
66 # you may not use this file except in compliance with the License.
364364
365365 return await self.db_pool.runInteraction(
366366 "count_e2e_one_time_keys", _count_e2e_one_time_keys
367 )
368
369 async def set_e2e_fallback_keys(
370 self, user_id: str, device_id: str, fallback_keys: JsonDict
371 ) -> None:
372 """Set the user's e2e fallback keys.
373
374 Args:
375 user_id: the user whose keys are being set
376 device_id: the device whose keys are being set
377 fallback_keys: the keys to set. This is a map from key ID (which is
378 of the form "algorithm:id") to key data.
379 """
380 # fallback_keys will usually only have one item in it, so using a for
381 # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
382 # FIXME: make sure that only one key per algorithm is uploaded
383 for key_id, fallback_key in fallback_keys.items():
384 algorithm, key_id = key_id.split(":", 1)
385 await self.db_pool.simple_upsert(
386 "e2e_fallback_keys_json",
387 keyvalues={
388 "user_id": user_id,
389 "device_id": device_id,
390 "algorithm": algorithm,
391 },
392 values={
393 "key_id": key_id,
394 "key_json": json_encoder.encode(fallback_key),
395 "used": False,
396 },
397 desc="set_e2e_fallback_key",
398 )
399
400 await self.invalidate_cache_and_stream(
401 "get_e2e_unused_fallback_key_types", (user_id, device_id)
402 )
403
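
Uploaded fallback keys are keyed as "algorithm:id"; set_e2e_fallback_keys splits that into the algorithm column and the bare key ID before upserting one row per (user, device, algorithm). For example (the key ID and key data here are made up):

fallback_keys = {"signed_curve25519:AAAAHg": {"key": "base64+public+key", "fallback": True}}

for key_id, fallback_key in fallback_keys.items():
    algorithm, key_id = key_id.split(":", 1)
    print(algorithm, key_id)  # -> signed_curve25519 AAAAHg
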
404 @cached(max_entries=10000)
405 async def get_e2e_unused_fallback_key_types(
406 self, user_id: str, device_id: str
407 ) -> List[str]:
408 """Returns the fallback key types that have an unused key.
409
410 Args:
411 user_id: the user whose keys are being queried
412 device_id: the device whose keys are being queried
413
414 Returns:
415 a list of key types
416 """
417 return await self.db_pool.simple_select_onecol(
418 "e2e_fallback_keys_json",
419 keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
420 retcol="algorithm",
421 desc="get_e2e_unused_fallback_key_types",
367422 )
368423
369424 async def get_e2e_cross_signing_key(
700755 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
701756 " LIMIT 1"
702757 )
758 fallback_sql = (
759 "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
760 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
761 " LIMIT 1"
762 )
703763 result = {}
704764 delete = []
765 used_fallbacks = []
705766 for user_id, device_id, algorithm in query_list:
706767 user_result = result.setdefault(user_id, {})
707768 device_result = user_result.setdefault(device_id, {})
708769 txn.execute(sql, (user_id, device_id, algorithm))
709 for key_id, key_json in txn:
770 otk_row = txn.fetchone()
771 if otk_row is not None:
772 key_id, key_json = otk_row
710773 device_result[algorithm + ":" + key_id] = key_json
711774 delete.append((user_id, device_id, algorithm, key_id))
775 else:
776 # no one-time key available, so see if there's a fallback
777 # key
778 txn.execute(fallback_sql, (user_id, device_id, algorithm))
779 fallback_row = txn.fetchone()
780 if fallback_row is not None:
781 key_id, key_json, used = fallback_row
782 device_result[algorithm + ":" + key_id] = key_json
783 if not used:
784 used_fallbacks.append(
785 (user_id, device_id, algorithm, key_id)
786 )
787
788 # drop any one-time keys that were claimed
712789 sql = (
713790 "DELETE FROM e2e_one_time_keys_json"
714791 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
725802 self._invalidate_cache_and_stream(
726803 txn, self.count_e2e_one_time_keys, (user_id, device_id)
727804 )
805 # mark fallback keys as used
806 for user_id, device_id, algorithm, key_id in used_fallbacks:
807 self.db_pool.simple_update_txn(
808 txn,
809 "e2e_fallback_keys_json",
810 {
811 "user_id": user_id,
812 "device_id": device_id,
813 "algorithm": algorithm,
814 "key_id": key_id,
815 },
816 {"used": True},
817 )
818 self._invalidate_cache_and_stream(
819 txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
820 )
821
728822 return result
729823
730824 return await self.db_pool.runInteraction(
752846 )
753847 self._invalidate_cache_and_stream(
754848 txn, self.count_e2e_one_time_keys, (user_id, device_id)
849 )
850 self.db_pool.simple_delete_txn(
851 txn,
852 table="dehydrated_devices",
853 keyvalues={"user_id": user_id, "device_id": device_id},
854 )
855 self.db_pool.simple_delete_txn(
856 txn,
857 table="e2e_fallback_keys_json",
858 keyvalues={"user_id": user_id, "device_id": device_id},
859 )
860 self._invalidate_cache_and_stream(
861 txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
755862 )
756863
757864 await self.db_pool.runInteraction(
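
The claim logic above tries a one-time key for the requested algorithm first and deletes it on success; only when none is left does it hand out the (reusable) fallback key, which is flagged as used rather than removed. A dictionary-based sketch of that decision, using made-up key material:

one_time_keys = {
    ("@bob:example.org", "DEV", "signed_curve25519"): [("AAAB", "otk-json")],
}
fallback_keys = {
    ("@bob:example.org", "DEV", "signed_curve25519"): {"key_id": "FALL", "key_json": "fb-json", "used": False},
}

def claim_key(user_id, device_id, algorithm):
    bucket = one_time_keys.get((user_id, device_id, algorithm))
    if bucket:
        key_id, key_json = bucket.pop(0)   # one-time keys are consumed when claimed
        return algorithm + ":" + key_id, key_json
    fallback = fallback_keys.get((user_id, device_id, algorithm))
    if fallback is not None:
        fallback["used"] = True            # fallback keys persist, but are marked used
        return algorithm + ":" + fallback["key_id"], fallback["key_json"]
    return None

print(claim_key("@bob:example.org", "DEV", "signed_curve25519"))  # consumes the one-time key
print(claim_key("@bob:example.org", "DEV", "signed_curve25519"))  # falls back to the fallback key
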
1818
1919 from synapse.api.errors import StoreError
2020 from synapse.events import EventBase
21 from synapse.metrics.background_process_metrics import run_as_background_process
21 from synapse.metrics.background_process_metrics import wrap_as_background_process
2222 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
2323 from synapse.storage.database import DatabasePool, LoggingTransaction
2424 from synapse.storage.databases.main.events_worker import EventsWorkerStore
3131
3232
3333 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
34 def __init__(self, database: DatabasePool, db_conn, hs):
35 super().__init__(database, db_conn, hs)
36
37 if hs.config.run_background_tasks:
38 hs.get_clock().looping_call(
39 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
40 )
41
3442 async def get_auth_chain(
3543 self, event_ids: Collection[str], include_given: bool = False
3644 ) -> List[EventBase]:
585593
586594 return [row["event_id"] for row in rows]
587595
588
589 class EventFederationStore(EventFederationWorkerStore):
590 """ Responsible for storing and serving up the various graphs associated
591 with an event, including the main event graph and the auth chains for an
592 event.
593
594 Also has methods for getting the front (latest) and back (oldest) edges
595 of the event graphs. These are used to generate the parents for new events
596 and backfilling from another server respectively.
597 """
598
599 EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
600
601 def __init__(self, database: DatabasePool, db_conn, hs):
602 super().__init__(database, db_conn, hs)
603
604 self.db_pool.updates.register_background_update_handler(
605 self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
606 )
607
608 hs.get_clock().looping_call(
609 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
610 )
611
612 def _delete_old_forward_extrem_cache(self):
596 @wrap_as_background_process("delete_old_forward_extrem_cache")
597 async def _delete_old_forward_extrem_cache(self) -> None:
613598 def _delete_old_forward_extrem_cache_txn(txn):
614599 # Delete entries older than a month, while making sure we don't delete
615600 # the only entries for a room.
626611 sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
627612 )
628613
629 return run_as_background_process(
630 "delete_old_forward_extrem_cache",
631 self.db_pool.runInteraction,
632 "_delete_old_forward_extrem_cache",
633 _delete_old_forward_extrem_cache_txn,
614 await self.db_pool.runInteraction(
615 "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
616 )
617
618
619 class EventFederationStore(EventFederationWorkerStore):
620 """ Responsible for storing and serving up the various graphs associated
621 with an event, including the main event graph and the auth chains for an
622 event.
623
624 Also has methods for getting the front (latest) and back (oldest) edges
625 of the event graphs. These are used to generate the parents for new events
626 and backfilling from another server respectively.
627 """
628
629 EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
630
631 def __init__(self, database: DatabasePool, db_conn, hs):
632 super().__init__(database, db_conn, hs)
633
634 self.db_pool.updates.register_background_update_handler(
635 self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
634636 )
635637
636638 async def clean_room_for_join(self, room_id):
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15
1615 import logging
1716 from typing import Dict, List, Optional, Tuple, Union
1817
1918 import attr
2019
21 from synapse.metrics.background_process_metrics import run_as_background_process
22 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
23 from synapse.storage.database import DatabasePool
20 from synapse.metrics.background_process_metrics import wrap_as_background_process
21 from synapse.storage._base import SQLBaseStore, db_to_json
22 from synapse.storage.database import DatabasePool, LoggingTransaction
2423 from synapse.util import json_encoder
2524 from synapse.util.caches.descriptors import cached
2625
7372 self.stream_ordering_month_ago = None
7473 self.stream_ordering_day_ago = None
7574
76 cur = LoggingTransaction(
77 db_conn.cursor(),
78 name="_find_stream_orderings_for_times_txn",
79 database_engine=self.database_engine,
80 )
75 cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
8176 self._find_stream_orderings_for_times_txn(cur)
8277 cur.close()
8378
8479 self.find_stream_orderings_looping_call = self._clock.looping_call(
8580 self._find_stream_orderings_for_times, 10 * 60 * 1000
8681 )
82
8783 self._rotate_delay = 3
8884 self._rotate_count = 10000
85 self._doing_notif_rotation = False
86 if hs.config.run_background_tasks:
87 self._rotate_notif_loop = self._clock.looping_call(
88 self._rotate_notifs, 30 * 60 * 1000
89 )
8990
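
The notification rotation is driven by a looping call but protects itself with the _doing_notif_rotation flag, so a slow rotation is never overlapped by the next tick. A minimal asyncio sketch of that non-re-entrant pattern; the interval and the "work" are placeholders:

import asyncio

class Rotator:
    def __init__(self):
        self._doing_rotation = False

    async def rotate(self):
        if self._doing_rotation:
            return                         # previous rotation still running; skip this tick
        self._doing_rotation = True
        try:
            await asyncio.sleep(0.1)       # placeholder for the real batched rotation work
        finally:
            self._doing_rotation = False

async def main():
    r = Rotator()
    # Two ticks fired close together: the second sees the flag and returns immediately.
    await asyncio.gather(r.rotate(), r.rotate())

asyncio.run(main())
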
9091 @cached(num_args=3, tree=True, max_entries=5000)
9192 async def get_unread_event_push_actions_by_room_for_user(
517518 "Error removing push actions after event persistence failure"
518519 )
519520
520 def _find_stream_orderings_for_times(self):
521 return run_as_background_process(
522 "event_push_action_stream_orderings",
523 self.db_pool.runInteraction,
521 @wrap_as_background_process("event_push_action_stream_orderings")
522 async def _find_stream_orderings_for_times(self) -> None:
523 await self.db_pool.runInteraction(
524524 "_find_stream_orderings_for_times",
525525 self._find_stream_orderings_for_times_txn,
526526 )
527527
528 def _find_stream_orderings_for_times_txn(self, txn):
528 def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
529529 logger.info("Searching for stream ordering 1 month ago")
530530 self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
531531 txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
655655 )
656656 return result[0] if result else None
657657
658
659 class EventPushActionsStore(EventPushActionsWorkerStore):
660 EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
661
662 def __init__(self, database: DatabasePool, db_conn, hs):
663 super().__init__(database, db_conn, hs)
664
665 self.db_pool.updates.register_background_index_update(
666 self.EPA_HIGHLIGHT_INDEX,
667 index_name="event_push_actions_u_highlight",
668 table="event_push_actions",
669 columns=["user_id", "stream_ordering"],
670 )
671
672 self.db_pool.updates.register_background_index_update(
673 "event_push_actions_highlights_index",
674 index_name="event_push_actions_highlights_index",
675 table="event_push_actions",
676 columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
677 where_clause="highlight=1",
678 )
679
680 self._doing_notif_rotation = False
681 self._rotate_notif_loop = self._clock.looping_call(
682 self._start_rotate_notifs, 30 * 60 * 1000
683 )
684
685 async def get_push_actions_for_user(
686 self, user_id, before=None, limit=50, only_highlight=False
687 ):
688 def f(txn):
689 before_clause = ""
690 if before:
691 before_clause = "AND epa.stream_ordering < ?"
692 args = [user_id, before, limit]
693 else:
694 args = [user_id, limit]
695
696 if only_highlight:
697 if len(before_clause) > 0:
698 before_clause += " "
699 before_clause += "AND epa.highlight = 1"
700
701 # NB. This assumes event_ids are globally unique since
702 # it makes the query easier to index
703 sql = (
704 "SELECT epa.event_id, epa.room_id,"
705 " epa.stream_ordering, epa.topological_ordering,"
706 " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
707 " FROM event_push_actions epa, events e"
708 " WHERE epa.event_id = e.event_id"
709 " AND epa.user_id = ? %s"
710 " AND epa.notif = 1"
711 " ORDER BY epa.stream_ordering DESC"
712 " LIMIT ?" % (before_clause,)
713 )
714 txn.execute(sql, args)
715 return self.db_pool.cursor_to_dict(txn)
716
717 push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
718 for pa in push_actions:
719 pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
720 return push_actions
721
722 async def get_latest_push_action_stream_ordering(self):
723 def f(txn):
724 txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
725 return txn.fetchone()
726
727 result = await self.db_pool.runInteraction(
728 "get_latest_push_action_stream_ordering", f
729 )
730 return result[0] or 0
731
732 def _remove_old_push_actions_before_txn(
733 self, txn, room_id, user_id, stream_ordering
734 ):
735 """
736 Purges old push actions for a user and room before a given
737 stream_ordering.
738
739 We do, however, keep a month's worth of highlighted notifications, so that
740 users can still get a list of recent highlights.
741
742 Args:
743 txn: The transaction
744 room_id: Room ID to delete from
745 user_id: user ID to delete for
746 stream_ordering: The lowest stream ordering which will
747 not be deleted.
748 """
749 txn.call_after(
750 self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
751 (room_id, user_id),
752 )
753
754 # We need to join on the events table to get the received_ts for
755 # event_push_actions and sqlite won't let us use a join in a delete so
756 # we can't just delete where received_ts < x. Furthermore we can
757 # only identify event_push_actions by a tuple of room_id, event_id
758 # so we can't use a subquery.
759 # Instead, we look up the stream ordering for the last event in that
760 # room received before the threshold time and delete event_push_actions
761 # in the room with a stream_ordering before that.
762 txn.execute(
763 "DELETE FROM event_push_actions "
764 " WHERE user_id = ? AND room_id = ? AND "
765 " stream_ordering <= ?"
766 " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
767 (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
768 )
769
770 txn.execute(
771 """
772 DELETE FROM event_push_summary
773 WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
774 """,
775 (room_id, user_id, stream_ordering),
776 )
777
778 def _start_rotate_notifs(self):
779 return run_as_background_process("rotate_notifs", self._rotate_notifs)
780
658 @wrap_as_background_process("rotate_notifs")
781659 async def _rotate_notifs(self):
782660 if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
783661 return
957835 )
958836
959837
838 class EventPushActionsStore(EventPushActionsWorkerStore):
839 EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
840
841 def __init__(self, database: DatabasePool, db_conn, hs):
842 super().__init__(database, db_conn, hs)
843
844 self.db_pool.updates.register_background_index_update(
845 self.EPA_HIGHLIGHT_INDEX,
846 index_name="event_push_actions_u_highlight",
847 table="event_push_actions",
848 columns=["user_id", "stream_ordering"],
849 )
850
851 self.db_pool.updates.register_background_index_update(
852 "event_push_actions_highlights_index",
853 index_name="event_push_actions_highlights_index",
854 table="event_push_actions",
855 columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
856 where_clause="highlight=1",
857 )
858
859 async def get_push_actions_for_user(
860 self, user_id, before=None, limit=50, only_highlight=False
861 ):
862 def f(txn):
863 before_clause = ""
864 if before:
865 before_clause = "AND epa.stream_ordering < ?"
866 args = [user_id, before, limit]
867 else:
868 args = [user_id, limit]
869
870 if only_highlight:
871 if len(before_clause) > 0:
872 before_clause += " "
873 before_clause += "AND epa.highlight = 1"
874
875 # NB. This assumes event_ids are globally unique since
876 # it makes the query easier to index
877 sql = (
878 "SELECT epa.event_id, epa.room_id,"
879 " epa.stream_ordering, epa.topological_ordering,"
880 " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
881 " FROM event_push_actions epa, events e"
882 " WHERE epa.event_id = e.event_id"
883 " AND epa.user_id = ? %s"
884 " AND epa.notif = 1"
885 " ORDER BY epa.stream_ordering DESC"
886 " LIMIT ?" % (before_clause,)
887 )
888 txn.execute(sql, args)
889 return self.db_pool.cursor_to_dict(txn)
890
891 push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
892 for pa in push_actions:
893 pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
894 return push_actions
895
896 async def get_latest_push_action_stream_ordering(self):
897 def f(txn):
898 txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
899 return txn.fetchone()
900
901 result = await self.db_pool.runInteraction(
902 "get_latest_push_action_stream_ordering", f
903 )
904 return result[0] or 0
905
906 def _remove_old_push_actions_before_txn(
907 self, txn, room_id, user_id, stream_ordering
908 ):
909 """
910 Purges old push actions for a user and room before a given
911 stream_ordering.
912
913 We do, however, keep a month's worth of highlighted notifications, so that
914 users can still get a list of recent highlights.
915
916 Args:
917 txn: The transaction
918 room_id: Room ID to delete from
919 user_id: user ID to delete for
920 stream_ordering: The lowest stream ordering which will
921 not be deleted.
922 """
923 txn.call_after(
924 self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
925 (room_id, user_id),
926 )
927
928 # We need to join on the events table to get the received_ts for
929 # event_push_actions and sqlite won't let us use a join in a delete so
930 # we can't just delete where received_ts < x. Furthermore we can
931 # only identify event_push_actions by a tuple of room_id, event_id
932 # so we can't use a subquery.
933 # Instead, we look up the stream ordering for the last event in that
934 # room received before the threshold time and delete event_push_actions
935 # in the room with a stream_ordering before that.
936 txn.execute(
937 "DELETE FROM event_push_actions "
938 " WHERE user_id = ? AND room_id = ? AND "
939 " stream_ordering <= ?"
940 " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
941 (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
942 )
943
944 txn.execute(
945 """
946 DELETE FROM event_push_summary
947 WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
948 """,
949 (room_id, user_id, stream_ordering),
950 )
951
952
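
The DELETE in _remove_old_push_actions_before_txn keeps a month's worth of highlights while clearing everything else up to the read cutoff: a row is removed only if it is at or below the cutoff and is either not a highlight or is older than stream_ordering_month_ago. The same predicate in plain Python, with illustrative orderings:

def should_delete(stream_ordering, highlight, cutoff, month_ago_ordering):
    return stream_ordering <= cutoff and (
        not highlight or stream_ordering < month_ago_ordering
    )

assert should_delete(50, highlight=False, cutoff=100, month_ago_ordering=80)      # plain notif: deleted
assert not should_delete(90, highlight=True, cutoff=100, month_ago_ordering=80)   # recent highlight: kept
assert should_delete(70, highlight=True, cutoff=100, month_ago_ordering=80)       # old highlight: deleted
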
960953 def _action_has_highlight(actions):
961954 for action in actions:
962955 try:
3333 from synapse.storage.databases.main.search import SearchEntry
3434 from synapse.storage.util.id_generators import MultiWriterIdGenerator
3535 from synapse.types import StateMap, get_domain_from_id
36 from synapse.util.frozenutils import frozendict_json_encoder
36 from synapse.util import json_encoder
3737 from synapse.util.iterutils import batch_iter
3838
3939 if TYPE_CHECKING:
4949 "",
5050 ["type", "origin_type", "origin_entity"],
5151 )
52
53
54 def encode_json(json_object):
55 """
56 Encode a Python object as JSON and return it in a Unicode string.
57 """
58 out = frozendict_json_encoder.encode(json_object)
59 if isinstance(out, bytes):
60 out = out.decode("utf8")
61 return out
6252
6353
6454 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
340330 min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
341331 max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
342332
333 # stream orderings should have been assigned by now
334 assert min_stream_order
335 assert max_stream_order
336
343337 self._update_forward_extremities_txn(
344338 txn,
345339 new_forward_extremities=new_forward_extremeties,
365359 # seen before.
366360
367361 self._store_event_txn(txn, events_and_contexts=events_and_contexts)
362
363 self._persist_transaction_ids_txn(txn, events_and_contexts)
368364
369365 # Insert into event_to_state_groups.
370366 self._store_event_state_mappings_txn(txn, events_and_contexts)
410406 # room_memberships, where applicable.
411407 self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
412408
409 def _persist_transaction_ids_txn(
410 self,
411 txn: LoggingTransaction,
412 events_and_contexts: List[Tuple[EventBase, EventContext]],
413 ):
414 """Persist the mapping from transaction IDs to event IDs (if defined).
415 """
416
417 to_insert = []
418 for event, _ in events_and_contexts:
419 token_id = getattr(event.internal_metadata, "token_id", None)
420 txn_id = getattr(event.internal_metadata, "txn_id", None)
421 if token_id and txn_id:
422 to_insert.append(
423 {
424 "event_id": event.event_id,
425 "room_id": event.room_id,
426 "user_id": event.sender,
427 "token_id": token_id,
428 "txn_id": txn_id,
429 "inserted_ts": self._clock.time_msec(),
430 }
431 )
432
433 if to_insert:
434 self.db_pool.simple_insert_many_txn(
435 txn, table="event_txn_id", values=to_insert,
436 )
437
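
_persist_transaction_ids_txn records which client transaction ID (scoped to the access token that sent it) produced which event, which is what allows a retried /send request to return the original event ID instead of creating a duplicate. A dictionary-based sketch of that deduplication; the exact key composition here is illustrative:

import itertools

txn_id_to_event = {}   # (room_id, token_id, txn_id) -> event_id
_event_counter = itertools.count(1)

def send_event(room_id, token_id, txn_id):
    """Return the stored event ID for a retried transaction, otherwise mint a new one."""
    key = (room_id, token_id, txn_id)
    if key in txn_id_to_event:
        return txn_id_to_event[key]
    event_id = "$event%d" % next(_event_counter)
    txn_id_to_event[key] = event_id
    return event_id

first = send_event("!room:example.org", token_id=42, txn_id="m123")
retry = send_event("!room:example.org", token_id=42, txn_id="m123")
assert first == retry == "$event1"
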
413438 def _update_current_state_txn(
414439 self,
415440 txn: LoggingTransaction,
431456 # so that async background tasks get told what happened.
432457 sql = """
433458 INSERT INTO current_state_delta_stream
434 (stream_id, room_id, type, state_key, event_id, prev_event_id)
435 SELECT ?, room_id, type, state_key, null, event_id
459 (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
460 SELECT ?, ?, room_id, type, state_key, null, event_id
436461 FROM current_state_events
437462 WHERE room_id = ?
438463 """
439 txn.execute(sql, (stream_id, room_id))
464 txn.execute(sql, (stream_id, self._instance_name, room_id))
440465
441466 self.db_pool.simple_delete_txn(
442467 txn, table="current_state_events", keyvalues={"room_id": room_id},
457482 #
458483 sql = """
459484 INSERT INTO current_state_delta_stream
460 (stream_id, room_id, type, state_key, event_id, prev_event_id)
461 SELECT ?, ?, ?, ?, ?, (
485 (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
486 SELECT ?, ?, ?, ?, ?, ?, (
462487 SELECT event_id FROM current_state_events
463488 WHERE room_id = ? AND type = ? AND state_key = ?
464489 )
468493 (
469494 (
470495 stream_id,
496 self._instance_name,
471497 room_id,
472498 etype,
473499 state_key,
742768 logger.exception("")
743769 raise
744770
745 metadata_json = encode_json(event.internal_metadata.get_dict())
771 metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
746772
747773 sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
748774 txn.execute(sql, (metadata_json, event.event_id))
758784 "event_stream_ordering": stream_order,
759785 "event_id": event.event_id,
760786 "state_group": state_group_id,
787 "instance_name": self._instance_name,
761788 },
762789 )
763790
796823 {
797824 "event_id": event.event_id,
798825 "room_id": event.room_id,
799 "internal_metadata": encode_json(
826 "internal_metadata": json_encoder.encode(
800827 event.internal_metadata.get_dict()
801828 ),
802 "json": encode_json(event_dict(event)),
829 "json": json_encoder.encode(event_dict(event)),
803830 "format_version": event.format_version,
804831 }
805832 for event, _ in events_and_contexts
10211048
10221049 def prefill():
10231050 for cache_entry in to_prefill:
1024 self.store._get_event_cache.prefill(
1025 (cache_entry[0].event_id,), cache_entry
1026 )
1051 self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
10271052
10281053 txn.call_after(prefill)
10291054
12401265 )
12411266
12421267 def _store_retention_policy_for_room_txn(self, txn, event):
1268 if not event.is_state():
1269 logger.debug("Ignoring non-state m.room.retention event")
1270 return
1271
12431272 if hasattr(event, "content") and (
12441273 "min_lifetime" in event.content or "max_lifetime" in event.content
12451274 ):
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1514 import itertools
1615 import logging
1716 import threading
3332 from synapse.events import EventBase, make_event_from_dict
3433 from synapse.events.utils import prune_event
3534 from synapse.logging.context import PreserveLoggingContext, current_context
36 from synapse.metrics.background_process_metrics import run_as_background_process
35 from synapse.metrics.background_process_metrics import (
36 run_as_background_process,
37 wrap_as_background_process,
38 )
3739 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
3840 from synapse.replication.tcp.streams import BackfillStream
3941 from synapse.replication.tcp.streams.events import EventsStream
4244 from synapse.storage.engines import PostgresEngine
4345 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
4446 from synapse.types import Collection, get_domain_from_id
45 from synapse.util.caches.descriptors import Cache, cached
47 from synapse.util.caches.descriptors import cached
48 from synapse.util.caches.lrucache import LruCache
4649 from synapse.util.iterutils import batch_iter
4750 from synapse.util.metrics import Measure
4851
7376
7477
7578 class EventsWorkerStore(SQLBaseStore):
79 # Whether to use dedicated DB threads for event fetching. This is only used
80 # if there are multiple DB threads available. When used, it will lock the DB
81 # thread for periods of time (so unit tests want to disable this when they
82 # run DB transactions on the main thread). See EVENT_QUEUE_* for more
83 # options controlling this.
84 USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
85
7686 def __init__(self, database: DatabasePool, db_conn, hs):
7787 super().__init__(database, db_conn, hs)
7888
129139 db_conn, "events", "stream_ordering", step=-1
130140 )
131141
132 self._get_event_cache = Cache(
133 "*getEvent*",
142 if hs.config.run_background_tasks:
143 # We periodically clean out old transaction ID mappings
144 self._clock.looping_call(
145 self._cleanup_old_transaction_ids, 5 * 60 * 1000,
146 )
147
148 self._get_event_cache = LruCache(
149 cache_name="*getEvent*",
134150 keylen=3,
135 max_entries=hs.config.caches.event_cache_size,
136 apply_cache_factor_from_config=False,
151 max_size=hs.config.caches.event_cache_size,
137152 )
138153
139154 self._event_fetch_lock = threading.Condition()
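
The event cache above is now constructed as an LruCache (keyed by a tuple, bounded by max_size) and populated with plain set() calls. The toy class below is not Synapse's implementation; it only sketches that keyed, size-bounded behaviour for illustration.

    # Toy stand-in for the keyed event cache (not Synapse's LruCache):
    # entries are keyed by a 1-tuple of event_id and evicted in LRU order.
    from collections import OrderedDict

    class TinyLruCache:
        def __init__(self, max_size: int):
            self.max_size = max_size
            self._data = OrderedDict()

        def set(self, key, value):
            self._data[key] = value
            self._data.move_to_end(key)
            while len(self._data) > self.max_size:
                self._data.popitem(last=False)  # drop the least-recently used entry

        def get(self, key, default=None):
            if key in self._data:
                self._data.move_to_end(key)
                return self._data[key]
            return default

    cache = TinyLruCache(max_size=2)
    cache.set(("$event_a",), {"json": "..."})
    cache.set(("$event_b",), {"json": "..."})
    cache.set(("$event_c",), {"json": "..."})   # evicts ("$event_a",)
    print(cache.get(("$event_a",)))              # None
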
521536
522537 if not event_list:
523538 single_threaded = self.database_engine.single_threaded
524 if single_threaded or i > EVENT_QUEUE_ITERATIONS:
539 if (
540 not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
541 or single_threaded
542 or i > EVENT_QUEUE_ITERATIONS
543 ):
525544 self._event_fetch_ongoing -= 1
526545 return
527546 else:
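
A minimal restatement of the loop-exit condition above as a standalone predicate; the function name and the constant value are local to this sketch and purely illustrative.

    # Standalone restatement of the fetcher's exit condition.
    EVENT_QUEUE_ITERATIONS = 3  # illustrative value

    def should_stop_fetch_thread(
        use_dedicated_threads: bool, single_threaded: bool, idle_iterations: int
    ) -> bool:
        """The fetcher gives up its DB thread when dedicated threads are disabled,
        when the engine is single-threaded, or after enough idle iterations."""
        return (
            not use_dedicated_threads
            or single_threaded
            or idle_iterations > EVENT_QUEUE_ITERATIONS
        )

    assert should_stop_fetch_thread(False, False, 0)      # toggle disabled
    assert should_stop_fetch_thread(True, True, 0)        # single-threaded engine
    assert not should_stop_fetch_thread(True, False, 1)   # keep looping
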
711730 internal_metadata_dict=internal_metadata,
712731 rejected_reason=rejected_reason,
713732 )
733 original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
714734
715735 event_map[event_id] = original_ev
716736
727747 event=original_ev, redacted_event=redacted_event
728748 )
729749
730 self._get_event_cache.prefill((event_id,), cache_entry)
750 self._get_event_cache.set((event_id,), cache_entry)
731751 result_map[event_id] = cache_entry
732752
733753 return result_map
778798
779799 * event_id (str)
780800
801 * stream_ordering (int): stream ordering for this event
802
781803 * json (str): json-encoded event structure
782804
783805 * internal_metadata (str): json-encoded internal metadata dict
810832 sql = """\
811833 SELECT
812834 e.event_id,
813 e.internal_metadata,
814 e.json,
815 e.format_version,
835 e.stream_ordering,
836 ej.internal_metadata,
837 ej.json,
838 ej.format_version,
816839 r.room_version,
817840 rej.reason
818 FROM event_json as e
819 LEFT JOIN rooms r USING (room_id)
841 FROM events AS e
842 JOIN event_json AS ej USING (event_id)
843 LEFT JOIN rooms r ON r.room_id = e.room_id
820844 LEFT JOIN rejections as rej USING (event_id)
821845 WHERE """
822846
830854 event_id = row[0]
831855 event_dict[event_id] = {
832856 "event_id": event_id,
833 "internal_metadata": row[1],
834 "json": row[2],
835 "format_version": row[3],
836 "room_version_id": row[4],
837 "rejected_reason": row[5],
857 "stream_ordering": row[1],
858 "internal_metadata": row[2],
859 "json": row[3],
860 "format_version": row[4],
861 "room_version_id": row[5],
862 "rejected_reason": row[6],
838863 "redactions": [],
839864 }
840865
10161041
10171042 return {"v1": complexity_v1}
10181043
1019 def get_current_backfill_token(self):
1020 """The current minimum token that backfilled events have reached"""
1021 return -self._backfill_id_gen.get_current_token()
1022
10231044 def get_current_events_token(self):
10241045 """The current maximum token that events have reached"""
10251046 return self._stream_id_gen.get_current_token()
10261047
10271048 async def get_all_new_forward_event_rows(
1028 self, last_id: int, current_id: int, limit: int
1049 self, instance_name: str, last_id: int, current_id: int, limit: int
10291050 ) -> List[Tuple]:
10301051 """Returns new events, for the Events replication stream
10311052
10491070 " LEFT JOIN state_events USING (event_id)"
10501071 " LEFT JOIN event_relations USING (event_id)"
10511072 " WHERE ? < stream_ordering AND stream_ordering <= ?"
1073 " AND instance_name = ?"
10521074 " ORDER BY stream_ordering ASC"
10531075 " LIMIT ?"
10541076 )
1055 txn.execute(sql, (last_id, current_id, limit))
1077 txn.execute(sql, (last_id, current_id, instance_name, limit))
10561078 return txn.fetchall()
10571079
10581080 return await self.db_pool.runInteraction(
10601082 )
10611083
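
A standalone sketch (in-memory SQLite, made-up rows) of the instance-filtered range query that the replication fetch above now issues; the simplified events schema is hypothetical.

    # Only rows written by the requested instance fall inside the range.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (stream_ordering INTEGER, event_id TEXT, instance_name TEXT)")
    conn.executemany(
        "INSERT INTO events VALUES (?, ?, ?)",
        [(1, "$a", "master"), (2, "$b", "event_persister1"), (3, "$c", "master")],
    )

    rows = conn.execute(
        "SELECT stream_ordering, event_id FROM events"
        " WHERE ? < stream_ordering AND stream_ordering <= ? AND instance_name = ?"
        " ORDER BY stream_ordering ASC LIMIT ?",
        (0, 3, "master", 10),
    ).fetchall()
    print(rows)  # [(1, '$a'), (3, '$c')]
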
10621084 async def get_ex_outlier_stream_rows(
1063 self, last_id: int, current_id: int
1085 self, instance_name: str, last_id: int, current_id: int
10641086 ) -> List[Tuple]:
10651087 """Returns de-outliered events, for the Events replication stream
10661088
10791101 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
10801102 " state_key, redacts, relates_to_id"
10811103 " FROM events AS e"
1082 " INNER JOIN ex_outlier_stream USING (event_id)"
1104 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
10831105 " LEFT JOIN redactions USING (event_id)"
10841106 " LEFT JOIN state_events USING (event_id)"
10851107 " LEFT JOIN event_relations USING (event_id)"
10861108 " WHERE ? < event_stream_ordering"
10871109 " AND event_stream_ordering <= ?"
1110 " AND out.instance_name = ?"
10881111 " ORDER BY event_stream_ordering ASC"
10891112 )
10901113
1091 txn.execute(sql, (last_id, current_id))
1114 txn.execute(sql, (last_id, current_id, instance_name))
10921115 return txn.fetchall()
10931116
10941117 return await self.db_pool.runInteraction(
11001123 ) -> Tuple[List[Tuple[int, list]], int, bool]:
11011124 """Get updates for backfill replication stream, including all new
11021125 backfilled events and events that have gone from being outliers to not.
1126
1127 NOTE: The IDs given here are from replication, and so should be
1128 *positive*.
11031129
11041130 Args:
11051131 instance_name: The writer we want to fetch updates from. Unused
11311157 " LEFT JOIN state_events USING (event_id)"
11321158 " LEFT JOIN event_relations USING (event_id)"
11331159 " WHERE ? > stream_ordering AND stream_ordering >= ?"
1160 " AND instance_name = ?"
11341161 " ORDER BY stream_ordering ASC"
11351162 " LIMIT ?"
11361163 )
1137 txn.execute(sql, (-last_id, -current_id, limit))
1164 txn.execute(sql, (-last_id, -current_id, instance_name, limit))
11381165 new_event_updates = [(row[0], row[1:]) for row in txn]
11391166
11401167 limited = False
11481175 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
11491176 " state_key, redacts, relates_to_id"
11501177 " FROM events AS e"
1151 " INNER JOIN ex_outlier_stream USING (event_id)"
1178 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
11521179 " LEFT JOIN redactions USING (event_id)"
11531180 " LEFT JOIN state_events USING (event_id)"
11541181 " LEFT JOIN event_relations USING (event_id)"
11551182 " WHERE ? > event_stream_ordering"
11561183 " AND event_stream_ordering >= ?"
1184 " AND out.instance_name = ?"
11571185 " ORDER BY event_stream_ordering DESC"
11581186 )
1159 txn.execute(sql, (-last_id, -upper_bound))
1187 txn.execute(sql, (-last_id, -upper_bound, instance_name))
11601188 new_event_updates.extend((row[0], row[1:]) for row in txn)
11611189
11621190 if len(new_event_updates) >= limit:
11701198 )
11711199
11721200 async def get_all_updated_current_state_deltas(
1173 self, from_token: int, to_token: int, target_row_count: int
1201 self, instance_name: str, from_token: int, to_token: int, target_row_count: int
11741202 ) -> Tuple[List[Tuple], int, bool]:
11751203 """Fetch updates from current_state_delta_stream
11761204
11961224 SELECT stream_id, room_id, type, state_key, event_id
11971225 FROM current_state_delta_stream
11981226 WHERE ? < stream_id AND stream_id <= ?
1227 AND instance_name = ?
11991228 ORDER BY stream_id ASC LIMIT ?
12001229 """
1201 txn.execute(sql, (from_token, to_token, target_row_count))
1230 txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
12021231 return txn.fetchall()
12031232
12041233 def get_deltas_for_stream_id_txn(txn, stream_id):
12861315 return await self.db_pool.runInteraction(
12871316 desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
12881317 )
1318
1319 async def get_event_id_from_transaction_id(
1320 self, room_id: str, user_id: str, token_id: int, txn_id: str
1321 ) -> Optional[str]:
1322 """Look up if we have already persisted an event for the transaction ID,
1323 returning the event ID if so.
1324 """
1325 return await self.db_pool.simple_select_one_onecol(
1326 table="event_txn_id",
1327 keyvalues={
1328 "room_id": room_id,
1329 "user_id": user_id,
1330 "token_id": token_id,
1331 "txn_id": txn_id,
1332 },
1333 retcol="event_id",
1334 allow_none=True,
1335 desc="get_event_id_from_transaction_id",
1336 )
1337
1338 async def get_already_persisted_events(
1339 self, events: Iterable[EventBase]
1340 ) -> Dict[str, str]:
1341 """Look up if we have already persisted an event for the transaction ID,
1342 returning a mapping from event ID in the given list to the event ID of
1343 an existing event.
1344
1345 Also checks for duplicates within the given events; if there are any,
1346 duplicates are mapped to the *first* event.
1347 """
1348
1349 mapping = {}
1350 txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
1351
1352 for event in events:
1353 token_id = getattr(event.internal_metadata, "token_id", None)
1354 txn_id = getattr(event.internal_metadata, "txn_id", None)
1355
1356 if token_id and txn_id:
1357 # Check if this is a duplicate of an event in the given events.
1358 existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
1359 if existing:
1360 mapping[event.event_id] = existing
1361 continue
1362
1363 # Check if this is a duplicate of an event we've already
1364 # persisted.
1365 existing = await self.get_event_id_from_transaction_id(
1366 event.room_id, event.sender, token_id, txn_id
1367 )
1368 if existing:
1369 mapping[event.event_id] = existing
1370 txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
1371 else:
1372 txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
1373
1374 return mapping
1375
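
The in-batch deduplication pass above can be sketched without the database lookup as follows; the event dicts and their fields are stand-ins for Synapse's event objects, not the real types.

    # Standalone sketch of duplicate detection by (room_id, token_id, txn_id).
    from typing import Dict, Tuple

    def dedupe_by_txn_id(events: list) -> Dict[str, str]:
        """Map each duplicate event_id to the first event seen for the same
        (room_id, token_id, txn_id) triple."""
        mapping = {}                                  # duplicate event_id -> first event_id
        txn_id_to_event = {}                          # type: Dict[Tuple[str, int, str], str]
        for ev in events:
            key = (ev["room_id"], ev["token_id"], ev["txn_id"])
            existing = txn_id_to_event.get(key)
            if existing:
                mapping[ev["event_id"]] = existing
            else:
                txn_id_to_event[key] = ev["event_id"]
        return mapping

    events = [
        {"event_id": "$1", "room_id": "!r", "token_id": 5, "txn_id": "abc"},
        {"event_id": "$2", "room_id": "!r", "token_id": 5, "txn_id": "abc"},  # duplicate
    ]
    print(dedupe_by_txn_id(events))  # {'$2': '$1'}
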
1376 @wrap_as_background_process("_cleanup_old_transaction_ids")
1377 async def _cleanup_old_transaction_ids(self):
1378 """Cleans out transaction id mappings older than 24hrs.
1379 """
1380
1381 def _cleanup_old_transaction_ids_txn(txn):
1382 sql = """
1383 DELETE FROM event_txn_id
1384 WHERE inserted_ts < ?
1385 """
1386 one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
1387 txn.execute(sql, (one_day_ago,))
1388
1389 return await self.db_pool.runInteraction(
1390 "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
1391 )
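
A standalone sketch of the 24-hour cutoff applied by the cleanup above; the clock reading and rows are made up and no database is involved.

    # Rows older than one day (by inserted_ts) are removed by the cleanup job.
    ONE_DAY_MS = 24 * 60 * 60 * 1000

    def cleanup_cutoff(now_ms: int) -> int:
        """Mappings with inserted_ts below this value are eligible for deletion."""
        return now_ms - ONE_DAY_MS

    now_ms = 1_600_000_000_000                       # example clock reading
    rows = [("$a", now_ms - 2 * ONE_DAY_MS), ("$b", now_ms - 1000)]
    kept = [event_id for event_id, ts in rows if ts >= cleanup_cutoff(now_ms)]
    print(kept)  # ['$b']
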
121121 # param, which is itself the 2-tuple (server_name, key_id).
122122 invalidations.append((server_name, key_id))
123123
124 await self.db_pool.runInteraction(
125 "store_server_verify_keys",
126 self.db_pool.simple_upsert_many_txn,
124 await self.db_pool.simple_upsert_many(
127125 table="server_signature_keys",
128126 key_names=("server_name", "key_id"),
129127 key_values=key_values,
134132 "verify_key",
135133 ),
136134 value_values=value_values,
135 desc="store_server_verify_keys",
137136 )
138137
139138 invalidate = self._get_server_verify_key.invalidate
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import calendar
15 import logging
16 import time
17 from typing import Dict
1418
1519 from synapse.metrics import GaugeBucketCollector
16 from synapse.metrics.background_process_metrics import run_as_background_process
20 from synapse.metrics.background_process_metrics import wrap_as_background_process
1721 from synapse.storage._base import SQLBaseStore
1822 from synapse.storage.database import DatabasePool
1923 from synapse.storage.databases.main.event_push_actions import (
2024 EventPushActionsWorkerStore,
2125 )
26
27 logger = logging.getLogger(__name__)
2228
2329 # Collect metrics on the number of forward extremities that exist.
2430 _extremities_collecter = GaugeBucketCollector(
5056 super().__init__(database, db_conn, hs)
5157
5258 # Read the extrems every 60 minutes
53 def read_forward_extremities():
54 # run as a background process to make sure that the database transactions
55 # have a logcontext to report to
56 return run_as_background_process(
57 "read_forward_extremities", self._read_forward_extremities
58 )
59
60 hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
61
59 if hs.config.run_background_tasks:
60 self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
61
62 # Used in _generate_user_daily_visits to keep track of progress
63 self._last_user_visit_update = self._get_start_of_day()
64
65 @wrap_as_background_process("read_forward_extremities")
6266 async def _read_forward_extremities(self):
6367 def fetch(txn):
6468 txn.execute(
136140 return count
137141
138142 return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
143
144 async def count_daily_users(self) -> int:
145 """
146 Counts the number of users who used this homeserver in the last 24 hours.
147 """
148 yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
149 return await self.db_pool.runInteraction(
150 "count_daily_users", self._count_users, yesterday
151 )
152
153 async def count_monthly_users(self) -> int:
154 """
155 Counts the number of users who used this homeserver in the last 30 days.
156 Note this method is intended for phonehome metrics only and is different
157 from the mau figure in synapse.storage.monthly_active_users which,
158 amongst other things, includes a 3 day grace period before a user counts.
159 """
160 thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
161 return await self.db_pool.runInteraction(
162 "count_monthly_users", self._count_users, thirty_days_ago
163 )
164
165 def _count_users(self, txn, time_from):
166 """
167 Returns number of users seen in the past time_from period
168 """
169 sql = """
170 SELECT COALESCE(count(*), 0) FROM (
171 SELECT user_id FROM user_ips
172 WHERE last_seen > ?
173 GROUP BY user_id
174 ) u
175 """
176 txn.execute(sql, (time_from,))
177 (count,) = txn.fetchone()
178 return count
179
180 async def count_r30_users(self) -> Dict[str, int]:
181 """
182 Counts the number of 30-day retained users, defined as:
183 * Users who have created their accounts more than 30 days ago
184 * Where last seen at most 30 days ago
185 * Where account creation and last_seen are > 30 days apart
186
187 Returns:
188 A mapping of counts globally as well as broken out by platform.
189 """
190
191 def _count_r30_users(txn):
192 thirty_days_in_secs = 86400 * 30
193 now = int(self._clock.time())
194 thirty_days_ago_in_secs = now - thirty_days_in_secs
195
196 sql = """
197 SELECT platform, COALESCE(count(*), 0) FROM (
198 SELECT
199 users.name, platform, users.creation_ts * 1000,
200 MAX(uip.last_seen)
201 FROM users
202 INNER JOIN (
203 SELECT
204 user_id,
205 last_seen,
206 CASE
207 WHEN user_agent LIKE '%%Android%%' THEN 'android'
208 WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
209 WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
210 WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
211 WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
212 ELSE 'unknown'
213 END
214 AS platform
215 FROM user_ips
216 ) uip
217 ON users.name = uip.user_id
218 AND users.appservice_id is NULL
219 AND users.creation_ts < ?
220 AND uip.last_seen/1000 > ?
221 AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
222 GROUP BY users.name, platform, users.creation_ts
223 ) u GROUP BY platform
224 """
225
226 results = {}
227 txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
228
229 for row in txn:
230 if row[0] == "unknown":
231 pass
232 results[row[0]] = row[1]
233
234 sql = """
235 SELECT COALESCE(count(*), 0) FROM (
236 SELECT users.name, users.creation_ts * 1000,
237 MAX(uip.last_seen)
238 FROM users
239 INNER JOIN (
240 SELECT
241 user_id,
242 last_seen
243 FROM user_ips
244 ) uip
245 ON users.name = uip.user_id
246 AND appservice_id is NULL
247 AND users.creation_ts < ?
248 AND uip.last_seen/1000 > ?
249 AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
250 GROUP BY users.name, users.creation_ts
251 ) u
252 """
253
254 txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
255
256 (count,) = txn.fetchone()
257 results["all"] = count
258
259 return results
260
261 return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
262
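
The R30 definition above can be restated as a pure predicate; timestamps here are in seconds and the sample values are invented.

    # Standalone restatement of the R30 retention criteria.
    THIRTY_DAYS_SECS = 86400 * 30

    def is_r30_user(creation_ts: int, last_seen_ts: int, now: int) -> bool:
        """Created more than 30 days ago, seen within the last 30 days, and with
        more than 30 days between account creation and last activity."""
        return (
            creation_ts < now - THIRTY_DAYS_SECS
            and last_seen_ts > now - THIRTY_DAYS_SECS
            and last_seen_ts - creation_ts > THIRTY_DAYS_SECS
        )

    now = 1_700_000_000
    print(is_r30_user(now - 40 * 86400, now - 5 * 86400, now))   # True
    print(is_r30_user(now - 10 * 86400, now - 5 * 86400, now))   # False (account too new)
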
263 def _get_start_of_day(self):
264 """
265 Returns millisecond unixtime for start of UTC day.
266 """
267 now = time.gmtime()
268 today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
269 return today_start * 1000
270
271 @wrap_as_background_process("generate_user_daily_visits")
272 async def generate_user_daily_visits(self) -> None:
273 """
274 Generates daily visit data for use in cohort/retention analysis
275 """
276
277 def _generate_user_daily_visits(txn):
278 logger.info("Calling _generate_user_daily_visits")
279 today_start = self._get_start_of_day()
280 a_day_in_milliseconds = 24 * 60 * 60 * 1000
281 now = self._clock.time_msec()
282
283 # A note on user_agent. Technically a given device can have multiple
284 # user agents, so we need to decide which one to pick. We could have
285 # handled this in a number of ways, but given that we don't care
286 # _that_ much we have gone for MAX(). For more details of the other
287 # options considered see
288 # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111
289 sql = """
290 INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent)
291 SELECT u.user_id, u.device_id, ?, MAX(u.user_agent)
292 FROM user_ips AS u
293 LEFT JOIN (
294 SELECT user_id, device_id, timestamp FROM user_daily_visits
295 WHERE timestamp = ?
296 ) udv
297 ON u.user_id = udv.user_id AND u.device_id=udv.device_id
298 INNER JOIN users ON users.name=u.user_id
299 WHERE last_seen > ? AND last_seen <= ?
300 AND udv.timestamp IS NULL AND users.is_guest=0
301 AND users.appservice_id IS NULL
302 GROUP BY u.user_id, u.device_id
303 """
304
305 # This means that the day has rolled over but there could still
306 # be entries from the previous day. There is an edge case
307 # where if the user logs in at 23:59 and overwrites their
308 # last_seen at 00:01 then they will not be counted in the
309 # previous day's stats - it is important that the query is run
310 # often to minimise this case.
311 if today_start > self._last_user_visit_update:
312 yesterday_start = today_start - a_day_in_milliseconds
313 txn.execute(
314 sql,
315 (
316 yesterday_start,
317 yesterday_start,
318 self._last_user_visit_update,
319 today_start,
320 ),
321 )
322 self._last_user_visit_update = today_start
323
324 txn.execute(
325 sql, (today_start, today_start, self._last_user_visit_update, now)
326 )
327 # Update _last_user_visit_update to now. The reason to do this
328 # rather than just clamping to the beginning of the day is to limit
329 # the size of the join - meaning that the query can be run more
330 # frequently
331 self._last_user_visit_update = now
332
333 await self.db_pool.runInteraction(
334 "generate_user_daily_visits", _generate_user_daily_visits
335 )
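
A standalone sketch of the day-rollover bookkeeping above: when the UTC day has changed since the last run, yesterday's window is recorded before today's. The helper names and the example values are local to this sketch.

    # Which (visit_day, last_seen_lower, last_seen_upper) windows to record.
    import calendar, time

    A_DAY_MS = 24 * 60 * 60 * 1000

    def start_of_utc_day(now_ms: int) -> int:
        t = time.gmtime(now_ms / 1000)
        return calendar.timegm((t.tm_year, t.tm_mon, t.tm_mday, 0, 0, 0)) * 1000

    def visit_windows(last_update_ms: int, now_ms: int) -> list:
        today_start = start_of_utc_day(now_ms)
        windows = []
        if today_start > last_update_ms:
            # The day rolled over since the last run: catch up yesterday first.
            yesterday_start = today_start - A_DAY_MS
            windows.append((yesterday_start, last_update_ms, today_start))
            last_update_ms = today_start
        windows.append((today_start, last_update_ms, now_ms))
        return windows

    print(visit_windows(last_update_ms=0, now_ms=1_000 * 60 * 60 * 30))  # rollover case
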
1414 import logging
1515 from typing import Dict, List
1616
17 from synapse.metrics.background_process_metrics import wrap_as_background_process
1718 from synapse.storage._base import SQLBaseStore
1819 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
1920 from synapse.util.caches.descriptors import cached
3031 super().__init__(database, db_conn, hs)
3132 self._clock = hs.get_clock()
3233 self.hs = hs
34
35 self._limit_usage_by_mau = hs.config.limit_usage_by_mau
36 self._max_mau_value = hs.config.max_mau_value
3337
3438 @cached(num_args=0)
3539 async def get_monthly_active_count(self) -> int:
123127 desc="user_last_seen_monthly_active",
124128 )
125129
126
127 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
128 def __init__(self, database: DatabasePool, db_conn, hs):
129 super().__init__(database, db_conn, hs)
130
131 self._limit_usage_by_mau = hs.config.limit_usage_by_mau
132 self._mau_stats_only = hs.config.mau_stats_only
133 self._max_mau_value = hs.config.max_mau_value
134
135 # Do not add more reserved users than the total allowable number
136 # cur = LoggingTransaction(
137 self.db_pool.new_transaction(
138 db_conn,
139 "initialise_mau_threepids",
140 [],
141 [],
142 self._initialise_reserved_users,
143 hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
144 )
145
146 def _initialise_reserved_users(self, txn, threepids):
147 """Ensures that reserved threepids are accounted for in the MAU table, should
148 be called on start up.
149
150 Args:
151 txn (cursor):
152 threepids (list[dict]): List of threepid dicts to reserve
153 """
154
155 # XXX what is this function trying to achieve? It upserts into
156 # monthly_active_users for each *registered* reserved mau user, but why?
157 #
158 # - shouldn't there already be an entry for each reserved user (at least
159 # if they have been active recently)?
160 #
161 # - if it's important that the timestamp is kept up to date, why do we only
162 # run this at startup?
163
164 for tp in threepids:
165 user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
166
167 if user_id:
168 is_support = self.is_support_user_txn(txn, user_id)
169 if not is_support:
170 # We do this manually here to avoid hitting #6791
171 self.db_pool.simple_upsert_txn(
172 txn,
173 table="monthly_active_users",
174 keyvalues={"user_id": user_id},
175 values={"timestamp": int(self._clock.time_msec())},
176 )
177 else:
178 logger.warning("mau limit reserved threepid %s not found in db" % tp)
179
130 @wrap_as_background_process("reap_monthly_active_users")
180131 async def reap_monthly_active_users(self):
181132 """Cleans out monthly active user table to ensure that no stale
182133 entries exist.
255206 await self.db_pool.runInteraction(
256207 "reap_monthly_active_users", _reap_users, reserved_users
257208 )
209
210
211 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
212 def __init__(self, database: DatabasePool, db_conn, hs):
213 super().__init__(database, db_conn, hs)
214
215 self._mau_stats_only = hs.config.mau_stats_only
216
217 # Do not add more reserved users than the total allowable number
218 self.db_pool.new_transaction(
219 db_conn,
220 "initialise_mau_threepids",
221 [],
222 [],
223 self._initialise_reserved_users,
224 hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
225 )
226
227 def _initialise_reserved_users(self, txn, threepids):
228 """Ensures that reserved threepids are accounted for in the MAU table, should
229 be called on start up.
230
231 Args:
232 txn (cursor):
233 threepids (list[dict]): List of threepid dicts to reserve
234 """
235
236 # XXX what is this function trying to achieve? It upserts into
237 # monthly_active_users for each *registered* reserved mau user, but why?
238 #
239 # - shouldn't there already be an entry for each reserved user (at least
240 # if they have been active recently)?
241 #
242 # - if it's important that the timestamp is kept up to date, why do we only
243 # run this at startup?
244
245 for tp in threepids:
246 user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
247
248 if user_id:
249 is_support = self.is_support_user_txn(txn, user_id)
250 if not is_support:
251 # We do this manually here to avoid hitting #6791
252 self.db_pool.simple_upsert_txn(
253 txn,
254 table="monthly_active_users",
255 keyvalues={"user_id": user_id},
256 values={"timestamp": int(self._clock.time_msec())},
257 )
258 else:
259 logger.warning("mau limit reserved threepid %s not found in db" % tp)
258260
259261 async def upsert_monthly_active_user(self, user_id: str) -> None:
260262 """Updates or inserts the user into the monthly active user table, which
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Dict, Optional
14 from typing import Any, Dict, List, Optional
1515
1616 from synapse.api.errors import StoreError
1717 from synapse.storage._base import SQLBaseStore
7171 )
7272
7373 async def set_profile_displayname(
74 self, user_localpart: str, new_displayname: str
74 self, user_localpart: str, new_displayname: Optional[str]
7575 ) -> None:
7676 await self.db_pool.simple_update_one(
7777 table="profiles",
8888 keyvalues={"user_id": user_localpart},
8989 updatevalues={"avatar_url": new_avatar_url},
9090 desc="set_profile_avatar_url",
91 )
92
93
94 class ProfileStore(ProfileWorkerStore):
95 async def add_remote_profile_cache(
96 self, user_id: str, displayname: str, avatar_url: str
97 ) -> None:
98 """Ensure we are caching the remote user's profiles.
99
100 This should only be called when `is_subscribed_remote_profile_for_user`
101 would return true for the user.
102 """
103 await self.db_pool.simple_upsert(
104 table="remote_profile_cache",
105 keyvalues={"user_id": user_id},
106 values={
107 "displayname": displayname,
108 "avatar_url": avatar_url,
109 "last_check": self._clock.time_msec(),
110 },
111 desc="add_remote_profile_cache",
11291 )
11392
11493 async def update_remote_profile_cache(
137116 desc="delete_remote_profile_cache",
138117 )
139118
140 async def get_remote_profile_cache_entries_that_expire(
141 self, last_checked: int
142 ) -> Dict[str, str]:
143 """Get all users who haven't been checked since `last_checked`
144 """
145
146 def _get_remote_profile_cache_entries_that_expire_txn(txn):
147 sql = """
148 SELECT user_id, displayname, avatar_url
149 FROM remote_profile_cache
150 WHERE last_check < ?
151 """
152
153 txn.execute(sql, (last_checked,))
154
155 return self.db_pool.cursor_to_dict(txn)
156
157 return await self.db_pool.runInteraction(
158 "get_remote_profile_cache_entries_that_expire",
159 _get_remote_profile_cache_entries_that_expire_txn,
160 )
161
162119 async def is_subscribed_remote_profile_for_user(self, user_id):
163120 """Check whether we are interested in a remote user's profile.
164121 """
183140
184141 if res:
185142 return True
143
144 async def get_remote_profile_cache_entries_that_expire(
145 self, last_checked: int
146 ) -> List[Dict[str, str]]:
147 """Get all users who haven't been checked since `last_checked`
148 """
149
150 def _get_remote_profile_cache_entries_that_expire_txn(txn):
151 sql = """
152 SELECT user_id, displayname, avatar_url
153 FROM remote_profile_cache
154 WHERE last_check < ?
155 """
156
157 txn.execute(sql, (last_checked,))
158
159 return self.db_pool.cursor_to_dict(txn)
160
161 return await self.db_pool.runInteraction(
162 "get_remote_profile_cache_entries_that_expire",
163 _get_remote_profile_cache_entries_that_expire_txn,
164 )
165
166
167 class ProfileStore(ProfileWorkerStore):
168 async def add_remote_profile_cache(
169 self, user_id: str, displayname: str, avatar_url: str
170 ) -> None:
171 """Ensure we are caching the remote user's profiles.
172
173 This should only be called when `is_subscribed_remote_profile_for_user`
174 would return true for the user.
175 """
176 await self.db_pool.simple_upsert(
177 table="remote_profile_cache",
178 keyvalues={"user_id": user_id},
179 values={
180 "displayname": displayname,
181 "avatar_url": avatar_url,
182 "last_check": self._clock.time_msec(),
183 },
184 desc="add_remote_profile_cache",
185 )
302302 lock=False,
303303 )
304304
305 user_has_pusher = self.get_if_user_has_pusher.cache.get(
305 user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
306306 (user_id,), None, update_metrics=False
307307 )
308308
2222 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
2323 from synapse.storage.database import DatabasePool
2424 from synapse.storage.util.id_generators import StreamIdGenerator
25 from synapse.types import JsonDict
2526 from synapse.util import json_encoder
26 from synapse.util.async_helpers import ObservableDeferred
2727 from synapse.util.caches.descriptors import cached, cachedList
2828 from synapse.util.caches.stream_change_cache import StreamChangeCache
2929
273273 }
274274 return results
275275
276 @cached(num_args=2,)
277 async def get_linearized_receipts_for_all_rooms(
278 self, to_key: int, from_key: Optional[int] = None
279 ) -> Dict[str, JsonDict]:
280 """Get receipts for all rooms between two stream_ids.
281
282 Args:
283 to_key: Max stream id to fetch receipts up to.
284 from_key: Min stream id to fetch receipts from. None fetches
285 from the start.
286
287 Returns:
288 A dictionary of room IDs to a list of receipts.
289 """
290
291 def f(txn):
292 if from_key:
293 sql = """
294 SELECT * FROM receipts_linearized WHERE
295 stream_id > ? AND stream_id <= ?
296 """
297 txn.execute(sql, [from_key, to_key])
298 else:
299 sql = """
300 SELECT * FROM receipts_linearized WHERE
301 stream_id <= ?
302 """
303
304 txn.execute(sql, [to_key])
305
306 return self.db_pool.cursor_to_dict(txn)
307
308 txn_results = await self.db_pool.runInteraction(
309 "get_linearized_receipts_for_all_rooms", f
310 )
311
312 results = {}
313 for row in txn_results:
314 # We want a single event per room, since we want to batch the
315 # receipts by room, event and type.
316 room_event = results.setdefault(
317 row["room_id"],
318 {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
319 )
320
321 # The content is of the form:
322 # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
323 event_entry = room_event["content"].setdefault(row["event_id"], {})
324 receipt_type = event_entry.setdefault(row["receipt_type"], {})
325
326 receipt_type[row["user_id"]] = db_to_json(row["data"])
327
328 return results
329
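
A standalone illustration (made-up identifiers) of the per-room structure the method above accumulates: one m.receipt event per room, with content keyed by event ID, then receipt type, then user ID.

    # Shape of a single per-room receipt entry built from the fetched rows.
    room_event = {
        "type": "m.receipt",
        "room_id": "!room:example.org",
        "content": {
            "$some_event:example.org": {
                "m.read": {
                    "@alice:example.org": {"ts": 1609459200000},
                },
            },
        },
    }

    # Merging another row for the same room only touches the nested dicts:
    entry = room_event["content"].setdefault("$some_event:example.org", {})
    entry.setdefault("m.read", {})["@bob:example.org"] = {"ts": 1609459260000}
    print(sorted(entry["m.read"]))  # ['@alice:example.org', '@bob:example.org']
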
276330 async def get_users_sent_receipts_between(
277331 self, last_id: int, current_id: int
278332 ) -> List[str]:
357411 if receipt_type != "m.read":
358412 return
359413
360 # Returns either an ObservableDeferred or the raw result
361 res = self.get_users_with_read_receipts_in_room.cache.get(
414 res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
362415 room_id, None, update_metrics=False
363416 )
364
365 # first handle the ObservableDeferred case
366 if isinstance(res, ObservableDeferred):
367 if res.has_called():
368 res = res.get_result()
369 else:
370 res = None
371417
372418 if res and user_id in res:
373419 # We'd only be adding to the set, so no point invalidating if the
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
22 # Copyright 2017-2018 New Vector Ltd
3 # Copyright 2019 The Matrix.org Foundation C.I.C.
3 # Copyright 2019,2020 The Matrix.org Foundation C.I.C.
44 #
55 # Licensed under the Apache License, Version 2.0 (the "License");
66 # you may not use this file except in compliance with the License.
1313 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
16
1716 import logging
1817 import re
1918 from typing import Any, Dict, List, Optional, Tuple
2019
2120 from synapse.api.constants import UserTypes
2221 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
23 from synapse.metrics.background_process_metrics import run_as_background_process
22 from synapse.metrics.background_process_metrics import wrap_as_background_process
2423 from synapse.storage._base import SQLBaseStore
2524 from synapse.storage.database import DatabasePool
2625 from synapse.storage.types import Cursor
4645 self._user_id_seq = build_sequence_generator(
4746 database.engine, find_max_generated_user_id_localpart, "user_id_seq",
4847 )
48
49 self._account_validity = hs.config.account_validity
50 if hs.config.run_background_tasks and self._account_validity.enabled:
51 self._clock.call_later(
52 0.0, self._set_expiration_date_when_missing,
53 )
54
55 # Create a background job for culling expired 3PID validity tokens
56 if hs.config.run_background_tasks:
57 self.clock.looping_call(
58 self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
59 )
4960
5061 @cached()
5162 async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
777788 "delete_threepid_session", delete_threepid_session_txn
778789 )
779790
791 @wrap_as_background_process("cull_expired_threepid_validation_tokens")
792 async def cull_expired_threepid_validation_tokens(self) -> None:
793 """Remove threepid validation tokens with expiry dates that have passed"""
794
795 def cull_expired_threepid_validation_tokens_txn(txn, ts):
796 sql = """
797 DELETE FROM threepid_validation_token WHERE
798 expires < ?
799 """
800 txn.execute(sql, (ts,))
801
802 await self.db_pool.runInteraction(
803 "cull_expired_threepid_validation_tokens",
804 cull_expired_threepid_validation_tokens_txn,
805 self.clock.time_msec(),
806 )
807
808 @wrap_as_background_process("account_validity_set_expiration_dates")
809 async def _set_expiration_date_when_missing(self):
810 """
811 Retrieves the list of registered users that don't have an expiration date, and
812 adds an expiration date for each of them.
813 """
814
815 def select_users_with_no_expiration_date_txn(txn):
816 """Retrieves the list of registered users with no expiration date from the
817 database, filtering out deactivated users.
818 """
819 sql = (
820 "SELECT users.name FROM users"
821 " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
822 " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
823 )
824 txn.execute(sql, [])
825
826 res = self.db_pool.cursor_to_dict(txn)
827 if res:
828 for user in res:
829 self.set_expiration_date_for_user_txn(
830 txn, user["name"], use_delta=True
831 )
832
833 await self.db_pool.runInteraction(
834 "get_users_with_no_expiration_date",
835 select_users_with_no_expiration_date_txn,
836 )
837
838 def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
839 """Sets an expiration date to the account with the given user ID.
840
841 Args:
842 user_id (str): User ID to set an expiration date for.
843 use_delta (bool): If set to False, the expiration date for the user will be
844 now + validity period. If set to True, this expiration date will be a
845 random value in the [now + period - d ; now + period] range, d being a
846 delta equal to 10% of the validity period.
847 """
848 now_ms = self._clock.time_msec()
849 expiration_ts = now_ms + self._account_validity.period
850
851 if use_delta:
852 expiration_ts = self.rand.randrange(
853 expiration_ts - self._account_validity.startup_job_max_delta,
854 expiration_ts,
855 )
856
857 self.db_pool.simple_upsert_txn(
858 txn,
859 "account_validity",
860 keyvalues={"user_id": user_id},
861 values={"expiration_ts_ms": expiration_ts, "email_sent": False},
862 )
863
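
A standalone sketch of the jittered expiration computed above; the period and delta mirror the account_validity settings described in the docstring, but the numbers here are examples only.

    # Expiration is now + period, optionally pulled back by a random delta.
    import random

    def expiration_ts(now_ms: int, period_ms: int, max_delta_ms: int, use_delta: bool) -> int:
        """now + period, optionally reduced by a random delta of at most
        max_delta_ms (a fraction of the validity period)."""
        ts = now_ms + period_ms
        if use_delta:
            ts = random.randrange(ts - max_delta_ms, ts)
        return ts

    now = 1_600_000_000_000
    period = 30 * 24 * 60 * 60 * 1000           # e.g. a 30 day validity period
    ts = expiration_ts(now, period, max_delta_ms=period // 10, use_delta=True)
    assert now + period - period // 10 <= ts < now + period
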
864 async def get_user_pending_deactivation(self) -> Optional[str]:
865 """
866 Gets one user from the table of users waiting to be parted from all the rooms
867 they're in.
868 """
869 return await self.db_pool.simple_select_one_onecol(
870 "users_pending_deactivation",
871 keyvalues={},
872 retcol="user_id",
873 allow_none=True,
874 desc="get_users_pending_deactivation",
875 )
876
877 async def del_user_pending_deactivation(self, user_id: str) -> None:
878 """
879 Removes the given user from the table of users who need to be parted from all the
880 rooms they're in, effectively marking that user as fully deactivated.
881 """
882 # XXX: This should be simple_delete_one but we failed to put a unique index on
883 # the table, so somehow duplicate entries have ended up in it.
884 await self.db_pool.simple_delete(
885 "users_pending_deactivation",
886 keyvalues={"user_id": user_id},
887 desc="del_user_pending_deactivation",
888 )
889
780890
781891 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
782892 def __init__(self, database: DatabasePool, db_conn, hs):
9101020 def __init__(self, database: DatabasePool, db_conn, hs):
9111021 super().__init__(database, db_conn, hs)
9121022
913 self._account_validity = hs.config.account_validity
9141023 self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
915
916 if self._account_validity.enabled:
917 self._clock.call_later(
918 0.0,
919 run_as_background_process,
920 "account_validity_set_expiration_dates",
921 self._set_expiration_date_when_missing,
922 )
923
924 # Create a background job for culling expired 3PID validity tokens
925 def start_cull():
926 # run as a background process to make sure that the database transactions
927 # have a logcontext to report to
928 return run_as_background_process(
929 "cull_expired_threepid_validation_tokens",
930 self.cull_expired_threepid_validation_tokens,
931 )
932
933 hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
9341024
9351025 async def add_access_token_to_user(
9361026 self,
9381028 token: str,
9391029 device_id: Optional[str],
9401030 valid_until_ms: Optional[int],
941 ) -> None:
1031 ) -> int:
9421032 """Adds an access token for the given user.
9431033
9441034 Args:
9481038 valid_until_ms: when the token is valid until. None for no expiry.
9491039 Raises:
9501040 StoreError if there was a problem adding this.
1041 Returns:
1042 The token ID
9511043 """
9521044 next_id = self._access_tokens_id_gen.get_next()
9531045
9611053 "valid_until_ms": valid_until_ms,
9621054 },
9631055 desc="add_access_token_to_user",
1056 )
1057
1058 return next_id
1059
1060 def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
1061 old_device_id = self.db_pool.simple_select_one_onecol_txn(
1062 txn, "access_tokens", {"token": token}, "device_id"
1063 )
1064
1065 self.db_pool.simple_update_txn(
1066 txn, "access_tokens", {"token": token}, {"device_id": device_id}
1067 )
1068
1069 self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
1070
1071 return old_device_id
1072
1073 async def set_device_for_access_token(self, token: str, device_id: str) -> str:
1074 """Sets the device ID associated with an access token.
1075
1076 Args:
1077 token: The access token to modify.
1078 device_id: The new device ID.
1079 Returns:
1080 The old device ID associated with the access token.
1081 """
1082
1083 return await self.db_pool.runInteraction(
1084 "set_device_for_access_token",
1085 self._set_device_for_access_token_txn,
1086 token,
1087 device_id,
9641088 )
9651089
9661090 async def register_user(
11201244 desc="record_user_external_id",
11211245 )
11221246
1123 async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
1247 async def user_set_password_hash(
1248 self, user_id: str, password_hash: Optional[str]
1249 ) -> None:
11241250 """
11251251 NB. This does *not* evict any cache because the one use for this
11261252 removes most of the entries subsequently anyway so it would be
12681394 "users_pending_deactivation",
12691395 values={"user_id": user_id},
12701396 desc="add_user_pending_deactivation",
1271 )
1272
1273 async def del_user_pending_deactivation(self, user_id: str) -> None:
1274 """
1275 Removes the given user to the table of users who need to be parted from all the
1276 rooms they're in, effectively marking that user as fully deactivated.
1277 """
1278 # XXX: This should be simple_delete_one but we failed to put a unique index on
1279 # the table, so somehow duplicate entries have ended up in it.
1280 await self.db_pool.simple_delete(
1281 "users_pending_deactivation",
1282 keyvalues={"user_id": user_id},
1283 desc="del_user_pending_deactivation",
1284 )
1285
1286 async def get_user_pending_deactivation(self) -> Optional[str]:
1287 """
1288 Gets one user from the table of users waiting to be parted from all the rooms
1289 they're in.
1290 """
1291 return await self.db_pool.simple_select_one_onecol(
1292 "users_pending_deactivation",
1293 keyvalues={},
1294 retcol="user_id",
1295 allow_none=True,
1296 desc="get_users_pending_deactivation",
12971397 )
12981398
12991399 async def validate_threepid_session(
14461546 start_or_continue_validation_session_txn,
14471547 )
14481548
1449 async def cull_expired_threepid_validation_tokens(self) -> None:
1450 """Remove threepid validation tokens with expiry dates that have passed"""
1451
1452 def cull_expired_threepid_validation_tokens_txn(txn, ts):
1453 sql = """
1454 DELETE FROM threepid_validation_token WHERE
1455 expires < ?
1456 """
1457 txn.execute(sql, (ts,))
1458
1459 await self.db_pool.runInteraction(
1460 "cull_expired_threepid_validation_tokens",
1461 cull_expired_threepid_validation_tokens_txn,
1462 self.clock.time_msec(),
1463 )
1464
14651549 async def set_user_deactivated_status(
14661550 self, user_id: str, deactivated: bool
14671551 ) -> None:
14911575 )
14921576 txn.call_after(self.is_guest.invalidate, (user_id,))
14931577
1494 async def _set_expiration_date_when_missing(self):
1495 """
1496 Retrieves the list of registered users that don't have an expiration date, and
1497 adds an expiration date for each of them.
1498 """
1499
1500 def select_users_with_no_expiration_date_txn(txn):
1501 """Retrieves the list of registered users with no expiration date from the
1502 database, filtering out deactivated users.
1503 """
1504 sql = (
1505 "SELECT users.name FROM users"
1506 " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
1507 " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
1508 )
1509 txn.execute(sql, [])
1510
1511 res = self.db_pool.cursor_to_dict(txn)
1512 if res:
1513 for user in res:
1514 self.set_expiration_date_for_user_txn(
1515 txn, user["name"], use_delta=True
1516 )
1517
1518 await self.db_pool.runInteraction(
1519 "get_users_with_no_expiration_date",
1520 select_users_with_no_expiration_date_txn,
1521 )
1522
1523 def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
1524 """Sets an expiration date to the account with the given user ID.
1525
1526 Args:
1527 user_id (str): User ID to set an expiration date for.
1528 use_delta (bool): If set to False, the expiration date for the user will be
1529 now + validity period. If set to True, this expiration date will be a
1530 random value in the [now + period - d ; now + period] range, d being a
1531 delta equal to 10% of the validity period.
1532 """
1533 now_ms = self._clock.time_msec()
1534 expiration_ts = now_ms + self._account_validity.period
1535
1536 if use_delta:
1537 expiration_ts = self.rand.randrange(
1538 expiration_ts - self._account_validity.startup_job_max_delta,
1539 expiration_ts,
1540 )
1541
1542 self.db_pool.simple_upsert_txn(
1543 txn,
1544 "account_validity",
1545 keyvalues={"user_id": user_id},
1546 values={"expiration_ts_ms": expiration_ts, "email_sent": False},
1547 )
1548
15491578
15501579 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
15511580 """
190190 return await self.db_pool.runInteraction(
191191 "count_public_rooms", _count_public_rooms_txn
192192 )
193
194 async def get_room_count(self) -> int:
195 """Retrieve the total number of rooms.
196 """
197
198 def f(txn):
199 sql = "SELECT count(*) FROM rooms"
200 txn.execute(sql)
201 row = txn.fetchone()
202 return row[0] or 0
203
204 return await self.db_pool.runInteraction("get_rooms", f)
193205
194206 async def get_largest_public_rooms(
195207 self,
856868 "get_all_new_public_rooms", get_all_new_public_rooms
857869 )
858870
871 async def get_rooms_for_retention_period_in_range(
872 self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
873 ) -> Dict[str, dict]:
874 """Retrieves all of the rooms within the given retention range.
875
876 Optionally includes the rooms which don't have a retention policy.
877
878 Args:
879 min_ms: Duration in milliseconds that defines the lower limit of
880 the range to handle (exclusive). If None, doesn't set a lower limit.
881 max_ms: Duration in milliseconds that defines the upper limit of
882 the range to handle (inclusive). If None, doesn't set an upper limit.
883 include_null: Whether to include rooms whose retention policy is NULL
884 in the returned set.
885
886 Returns:
887 The rooms within this range, along with their retention
888 policy. The key is "room_id", and maps to a dict describing the retention
889 policy associated with this room ID. The keys for this nested dict are
890 "min_lifetime" (int|None), and "max_lifetime" (int|None).
891 """
892
893 def get_rooms_for_retention_period_in_range_txn(txn):
894 range_conditions = []
895 args = []
896
897 if min_ms is not None:
898 range_conditions.append("max_lifetime > ?")
899 args.append(min_ms)
900
901 if max_ms is not None:
902 range_conditions.append("max_lifetime <= ?")
903 args.append(max_ms)
904
905 # Do a first query which will retrieve the rooms that have a retention policy
906 # in their current state.
907 sql = """
908 SELECT room_id, min_lifetime, max_lifetime FROM room_retention
909 INNER JOIN current_state_events USING (event_id, room_id)
910 """
911
912 if len(range_conditions):
913 sql += " WHERE (" + " AND ".join(range_conditions) + ")"
914
915 if include_null:
916 sql += " OR max_lifetime IS NULL"
917
918 txn.execute(sql, args)
919
920 rows = self.db_pool.cursor_to_dict(txn)
921 rooms_dict = {}
922
923 for row in rows:
924 rooms_dict[row["room_id"]] = {
925 "min_lifetime": row["min_lifetime"],
926 "max_lifetime": row["max_lifetime"],
927 }
928
929 if include_null:
930 # If required, do a second query that retrieves all of the rooms we know
931 # of so we can handle rooms with no retention policy.
932 sql = "SELECT DISTINCT room_id FROM current_state_events"
933
934 txn.execute(sql)
935
936 rows = self.db_pool.cursor_to_dict(txn)
937
938 # If a room isn't already in the dict (i.e. it doesn't have a retention
939 # policy in its state), add it with a null policy.
940 for row in rows:
941 if row["room_id"] not in rooms_dict:
942 rooms_dict[row["room_id"]] = {
943 "min_lifetime": None,
944 "max_lifetime": None,
945 }
946
947 return rooms_dict
948
949 return await self.db_pool.runInteraction(
950 "get_rooms_for_retention_period_in_range",
951 get_rooms_for_retention_period_in_range_txn,
952 )
953
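
A standalone sketch (hypothetical consumer, made-up values) of how a purge job might interpret the result of the range query above; min_ms is exclusive, max_ms is inclusive, and include_null=True adds rooms that have no retention policy at all.

    # Interpreting the {room_id: policy} mapping returned by the query.
    rooms = {
        "!short:hs": {"min_lifetime": None, "max_lifetime": 86400000},
        "!none:hs": {"min_lifetime": None, "max_lifetime": None},   # only present with include_null
    }

    def purge_before(now_ms: int, policy: dict, default_max_lifetime: int) -> int:
        """Events older than this timestamp would be eligible for purging."""
        max_lifetime = policy["max_lifetime"]
        if max_lifetime is None:
            max_lifetime = default_max_lifetime     # hypothetical server default
        return now_ms - max_lifetime

    now = 1_600_000_000_000
    for room_id, policy in rooms.items():
        print(room_id, purge_before(now, policy, default_max_lifetime=7 * 86400000))
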
859954
860955 class RoomBackgroundUpdateStore(SQLBaseStore):
861956 REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
12911386 )
12921387 self.hs.get_notifier().on_new_replication_data()
12931388
1294 async def get_room_count(self) -> int:
1295 """Retrieve the total number of rooms.
1296 """
1297
1298 def f(txn):
1299 sql = "SELECT count(*) FROM rooms"
1300 txn.execute(sql)
1301 row = txn.fetchone()
1302 return row[0] or 0
1303
1304 return await self.db_pool.runInteraction("get_rooms", f)
1305
13061389 async def add_event_report(
13071390 self,
13081391 room_id: str,
14451528 self.is_room_blocked,
14461529 (room_id,),
14471530 )
1448
1449 async def get_rooms_for_retention_period_in_range(
1450 self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
1451 ) -> Dict[str, dict]:
1452 """Retrieves all of the rooms within the given retention range.
1453
1454 Optionally includes the rooms which don't have a retention policy.
1455
1456 Args:
1457 min_ms: Duration in milliseconds that define the lower limit of
1458 the range to handle (exclusive). If None, doesn't set a lower limit.
1459 max_ms: Duration in milliseconds that define the upper limit of
1460 the range to handle (inclusive). If None, doesn't set an upper limit.
1461 include_null: Whether to include rooms which retention policy is NULL
1462 in the returned set.
1463
1464 Returns:
1465 The rooms within this range, along with their retention
1466 policy. The key is "room_id", and maps to a dict describing the retention
1467 policy associated with this room ID. The keys for this nested dict are
1468 "min_lifetime" (int|None), and "max_lifetime" (int|None).
1469 """
1470
1471 def get_rooms_for_retention_period_in_range_txn(txn):
1472 range_conditions = []
1473 args = []
1474
1475 if min_ms is not None:
1476 range_conditions.append("max_lifetime > ?")
1477 args.append(min_ms)
1478
1479 if max_ms is not None:
1480 range_conditions.append("max_lifetime <= ?")
1481 args.append(max_ms)
1482
1483 # Do a first query which will retrieve the rooms that have a retention policy
1484 # in their current state.
1485 sql = """
1486 SELECT room_id, min_lifetime, max_lifetime FROM room_retention
1487 INNER JOIN current_state_events USING (event_id, room_id)
1488 """
1489
1490 if len(range_conditions):
1491 sql += " WHERE (" + " AND ".join(range_conditions) + ")"
1492
1493 if include_null:
1494 sql += " OR max_lifetime IS NULL"
1495
1496 txn.execute(sql, args)
1497
1498 rows = self.db_pool.cursor_to_dict(txn)
1499 rooms_dict = {}
1500
1501 for row in rows:
1502 rooms_dict[row["room_id"]] = {
1503 "min_lifetime": row["min_lifetime"],
1504 "max_lifetime": row["max_lifetime"],
1505 }
1506
1507 if include_null:
1508 # If required, do a second query that retrieves all of the rooms we know
1509 # of so we can handle rooms with no retention policy.
1510 sql = "SELECT DISTINCT room_id FROM current_state_events"
1511
1512 txn.execute(sql)
1513
1514 rows = self.db_pool.cursor_to_dict(txn)
1515
1516 # If a room isn't already in the dict (i.e. it doesn't have a retention
1517 # policy in its state), add it with a null policy.
1518 for row in rows:
1519 if row["room_id"] not in rooms_dict:
1520 rooms_dict[row["room_id"]] = {
1521 "min_lifetime": None,
1522 "max_lifetime": None,
1523 }
1524
1525 return rooms_dict
1526
1527 rooms = await self.db_pool.runInteraction(
1528 "get_rooms_for_retention_period_in_range",
1529 get_rooms_for_retention_period_in_range_txn,
1530 )
1531
1532 return rooms
1919 from synapse.events import EventBase
2020 from synapse.events.snapshot import EventContext
2121 from synapse.metrics import LaterGauge
22 from synapse.metrics.background_process_metrics import run_as_background_process
23 from synapse.storage._base import (
24 LoggingTransaction,
25 SQLBaseStore,
26 db_to_json,
27 make_in_list_sql_clause,
22 from synapse.metrics.background_process_metrics import (
23 run_as_background_process,
24 wrap_as_background_process,
2825 )
26 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
2927 from synapse.storage.database import DatabasePool
3028 from synapse.storage.databases.main.events_worker import EventsWorkerStore
3129 from synapse.storage.engines import Sqlite3Engine
5957 # background update still running?
6058 self._current_state_events_membership_up_to_date = False
6159
62 txn = LoggingTransaction(
63 db_conn.cursor(),
64 name="_check_safe_current_state_events_membership_updated",
65 database_engine=self.database_engine,
60 txn = db_conn.cursor(
61 txn_name="_check_safe_current_state_events_membership_updated"
6662 )
6763 self._check_safe_current_state_events_membership_updated_txn(txn)
6864 txn.close()
6965
70 if self.hs.config.metrics_flags.known_servers:
66 if (
67 self.hs.config.run_background_tasks
68 and self.hs.config.metrics_flags.known_servers
69 ):
7170 self._known_servers_count = 1
7271 self.hs.get_clock().looping_call(
73 run_as_background_process,
74 60 * 1000,
75 "_count_known_servers",
76 self._count_known_servers,
72 self._count_known_servers, 60 * 1000,
7773 )
7874 self.hs.get_clock().call_later(
79 1000,
80 run_as_background_process,
81 "_count_known_servers",
82 self._count_known_servers,
75 1000, self._count_known_servers,
8376 )
8477 LaterGauge(
8578 "synapse_federation_known_servers",
8881 lambda: self._known_servers_count,
8982 )
9083
84 @wrap_as_background_process("_count_known_servers")
9185 async def _count_known_servers(self):
9286 """
9387 Count the servers that this server knows about.
534528 # If we do then we can reuse that result and simply update it with
535529 # any membership changes in `delta_ids`
536530 if context.prev_group and context.delta_ids:
537 prev_res = self._get_joined_users_from_context.cache.get(
531 prev_res = self._get_joined_users_from_context.cache.get_immediate(
538532 (room_id, context.prev_group), None
539533 )
540534 if prev_res and isinstance(prev_res, dict):
6565 row[8] = bytes(row[8]).decode("utf-8")
6666 row[11] = bytes(row[11]).decode("utf-8")
6767 cur.execute(
68 database_engine.convert_param_style(
69 """
70 INSERT into pushers2 (
71 id, user_name, access_token, profile_tag, kind,
72 app_id, app_display_name, device_display_name,
73 pushkey, ts, lang, data, last_token, last_success,
74 failing_since
75 ) values (%s)"""
76 % (",".join(["?" for _ in range(len(row))]))
77 ),
68 """
69 INSERT into pushers2 (
70 id, user_name, access_token, profile_tag, kind,
71 app_id, app_display_name, device_display_name,
72 pushkey, ts, lang, data, last_token, last_success,
73 failing_since
74 ) values (%s)
75 """
76 % (",".join(["?" for _ in range(len(row))])),
7877 row,
7978 )
8079 count += 1
7070 " VALUES (?, ?)"
7171 )
7272
73 sql = database_engine.convert_param_style(sql)
74
7573 cur.execute(sql, ("event_search", progress_json))
7674
7775
4949 " VALUES (?, ?)"
5050 )
5151
52 sql = database_engine.convert_param_style(sql)
53
5452 cur.execute(sql, ("event_origin_server_ts", progress_json))
5553
5654
5858 user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
5959 for chunk in user_chunks:
6060 cur.execute(
61 database_engine.convert_param_style(
62 "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
63 % (",".join("?" for _ in chunk),)
64 ),
61 "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
62 % (",".join("?" for _ in chunk),),
6563 [as_id] + chunk,
6664 )
6464 row = list(row)
6565 row[12] = token_to_stream_ordering(row[12])
6666 cur.execute(
67 database_engine.convert_param_style(
68 """
69 INSERT into pushers2 (
70 id, user_name, access_token, profile_tag, kind,
71 app_id, app_display_name, device_display_name,
72 pushkey, ts, lang, data, last_stream_ordering, last_success,
73 failing_since
74 ) values (%s)"""
75 % (",".join(["?" for _ in range(len(row))]))
76 ),
67 """
68 INSERT into pushers2 (
69 id, user_name, access_token, profile_tag, kind,
70 app_id, app_display_name, device_display_name,
71 pushkey, ts, lang, data, last_stream_ordering, last_success,
72 failing_since
73 ) values (%s)
74 """
75 % (",".join(["?" for _ in range(len(row))])),
7776 row,
7877 )
7978 count += 1
5454 " VALUES (?, ?)"
5555 )
5656
57 sql = database_engine.convert_param_style(sql)
58
5957 cur.execute(sql, ("event_search_order", progress_json))
6058
6159
4949 " VALUES (?, ?)"
5050 )
5151
52 sql = database_engine.convert_param_style(sql)
53
5452 cur.execute(sql, ("event_fields_sender_url", progress_json))
5553
5654
2222
2323 def run_upgrade(cur, database_engine, *args, **kwargs):
2424 cur.execute(
25 database_engine.convert_param_style(
26 "UPDATE remote_media_cache SET last_access_ts = ?"
27 ),
28 (int(time.time() * 1000),),
25 "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
2926 )
00 import logging
1 from io import StringIO
12
23 from synapse.storage.engines import PostgresEngine
4 from synapse.storage.prepare_database import execute_statements_from_stream
35
46 logger = logging.getLogger(__name__)
57
4547 select_clause,
4648 )
4749
48 if isinstance(database_engine, PostgresEngine):
49 cur.execute(sql)
50 else:
51 cur.executescript(sql)
50 execute_statements_from_stream(cur, StringIO(sql))
6767 INNER JOIN room_memberships AS r USING (event_id)
6868 WHERE type = 'm.room.member' AND state_key LIKE ?
6969 """
70 sql = database_engine.convert_param_style(sql)
7170 cur.execute(sql, ("%:" + config.server_name,))
7271
7372 cur.execute(
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE TABLE IF NOT EXISTS dehydrated_devices(
16 user_id TEXT NOT NULL PRIMARY KEY,
17 device_id TEXT NOT NULL,
18 device_data TEXT NOT NULL -- JSON-encoded client-defined data
19 );
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
16 user_id TEXT NOT NULL, -- The user this fallback key is for.
17 device_id TEXT NOT NULL, -- The device this fallback key is for.
18 algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
19 key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
20 key_json TEXT NOT NULL, -- The key as a JSON blob.
21 used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
22 CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
23 );
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15
16 -- A unique and immutable mapping between instance name and an integer ID. This
17 -- lets us refer to instances via a small ID in e.g. stream tokens, without
18 -- having to encode the full name.
19 CREATE TABLE IF NOT EXISTS instance_map (
20 instance_id SERIAL PRIMARY KEY,
21 instance_name TEXT NOT NULL
22 );
23
24 CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15
16 -- A map of recent events persisted with transaction IDs. Used to deduplicate
17 -- send event requests with the same transaction ID.
18 --
19 -- Note: transaction IDs are scoped to the room ID/user ID/access token that was
20 -- used to make the request.
21 --
22 -- Note: The foreign key constraints are ON DELETE CASCADE: if we delete the
23 -- events or access token then we no longer want to de-duplicate against the event.
24 CREATE TABLE IF NOT EXISTS event_txn_id (
25 event_id TEXT NOT NULL,
26 room_id TEXT NOT NULL,
27 user_id TEXT NOT NULL,
28 token_id BIGINT NOT NULL,
29 txn_id TEXT NOT NULL,
30 inserted_ts BIGINT NOT NULL,
31 FOREIGN KEY (event_id)
32 REFERENCES events (event_id) ON DELETE CASCADE,
33 FOREIGN KEY (token_id)
34 REFERENCES access_tokens (id) ON DELETE CASCADE
35 );
36
37 CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
38 CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
39 CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
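For illustration only (not part of the schema delta): a minimal sketch of how a deduplication lookup against this table might be issued from Python. The helper name and cursor handling are hypothetical; the columns and unique index are the ones defined above.

    # Hypothetical helper: find an event already persisted for the same
    # (room_id, user_id, token_id, txn_id), mirroring event_txn_id_txn_id above.
    def lookup_event_for_txn(cur, room_id, user_id, token_id, txn_id):
        cur.execute(
            """
            SELECT event_id FROM event_txn_id
            WHERE room_id = ? AND user_id = ? AND token_id = ? AND txn_id = ?
            """,
            (room_id, user_id, token_id, txn_id),
        )
        row = cur.fetchone()
        return row[0] if row else None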
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT;
16 ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT;
0 /* Copyright 2020 The Matrix.org Foundation C.I.C.
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- Add new column to user_daily_visits to track user agent
16 ALTER TABLE user_daily_visits
17 ADD COLUMN user_agent TEXT;
0 /* Copyright 2020 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
16 ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
5252 )
5353 from synapse.storage.databases.main.events_worker import EventsWorkerStore
5454 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
55 from synapse.storage.util.id_generators import MultiWriterIdGenerator
5556 from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
57 from synapse.util.caches.descriptors import cached
5658 from synapse.util.caches.stream_change_cache import StreamChangeCache
5759
5860 if TYPE_CHECKING:
207209 )
208210
209211
212 def _filter_results(
213 lower_token: Optional[RoomStreamToken],
214 upper_token: Optional[RoomStreamToken],
215 instance_name: str,
216 topological_ordering: int,
217 stream_ordering: int,
218 ) -> bool:
219 """Returns True if the event persisted by the given instance at the given
220 topological/stream_ordering falls between the two tokens (taking a None
221 token to mean unbounded).
222
223 Used to filter results from fetching events in the DB against the given
224 tokens. This is necessary to handle the case where the tokens include
225 position maps, which we handle by fetching more than necessary from the DB
226 and then filtering (rather than attempting to construct a complicated SQL
227 query).
228 """
229
230 event_historical_tuple = (
231 topological_ordering,
232 stream_ordering,
233 )
234
235 if lower_token:
236 if lower_token.topological is not None:
237 # If these are historical tokens we compare the `(topological, stream)`
238 # tuples.
239 if event_historical_tuple <= lower_token.as_historical_tuple():
240 return False
241
242 else:
243 # If these are live tokens we compare the stream ordering against the
244 # writers stream position.
245 if stream_ordering <= lower_token.get_stream_pos_for_instance(
246 instance_name
247 ):
248 return False
249
250 if upper_token:
251 if upper_token.topological is not None:
252 if upper_token.as_historical_tuple() < event_historical_tuple:
253 return False
254 else:
255 if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
256 return False
257
258 return True
259
260
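A rough illustration of the over-fetch-then-filter approach described in the docstring above; it relies on `RoomStreamToken` (imported from synapse.types in this module), and the instance names and positions are made up.

    # Paginate between two live tokens where the writer "worker2" is ahead of the
    # minimum position recorded in the lower token.
    lower = RoomStreamToken(None, 56, {"worker2": 58})
    upper = RoomStreamToken(None, 60)

    rows = [
        # (instance_name, topological_ordering, stream_ordering)
        ("worker1", 10, 57),  # kept: worker1's lower bound is 56, and 57 <= 60
        ("worker2", 10, 57),  # dropped: worker2 was already at 58 in `lower`
        ("worker2", 10, 59),  # kept: 59 > 58 and 59 <= 60
    ]
    kept = [r for r in rows if _filter_results(lower, upper, *r)]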
210261 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
211262 # NB: This may create SQL clauses that don't optimise well (and we don't
212263 # have indices on all possible clauses). E.g. it may create
304355 raise NotImplementedError()
305356
306357 def get_room_max_token(self) -> RoomStreamToken:
307 return RoomStreamToken(None, self.get_room_max_stream_ordering())
358 """Get a `RoomStreamToken` that marks the current maximum persisted
359 position of the events stream. Useful to get a token that represents
360 "now".
361
362 The token returned is a "live" token that may have an instance_map
363 component.
364 """
365
366 min_pos = self._stream_id_gen.get_current_token()
367
368 positions = {}
369 if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
370 # The `min_pos` is the minimum position that we know all instances
371 # have finished persisting to, so we only care about instances whose
372 # positions are ahead of that. (Instance positions can be behind the
373 # min position as there are times we can work out that the minimum
374 # position is ahead of the naive minimum across all current
375 # positions. See MultiWriterIdGenerator for details)
376 positions = {
377 i: p
378 for i, p in self._stream_id_gen.get_positions().items()
379 if p > min_pos
380 }
381
382 return RoomStreamToken(None, min_pos, positions)
308383
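As a sketch of what this returns when sharded event persisters are in use (the instance names and positions here are invented): if the minimum fully-persisted position is 56 but two writers have already reached 58 and 59, only the writers ahead of the minimum are recorded.

    token = RoomStreamToken(None, 56, {"event_persister2": 58, "event_persister3": 59})
    # Serialised via RoomStreamToken.to_string this becomes the "m" form described
    # in synapse/types.py, e.g. "m56~2.58~3.59" once the instance names have been
    # mapped to their integer IDs in the instance_map table.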
309384 async def get_room_events_stream_for_rooms(
310385 self,
403478 if from_key == to_key:
404479 return [], from_key
405480
406 from_id = from_key.stream
407 to_id = to_key.stream
408
409 has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
481 has_changed = self._events_stream_cache.has_entity_changed(
482 room_id, from_key.stream
483 )
410484
411485 if not has_changed:
412486 return [], from_key
413487
414488 def f(txn):
415 sql = (
416 "SELECT event_id, stream_ordering FROM events WHERE"
417 " room_id = ?"
418 " AND not outlier"
419 " AND stream_ordering > ? AND stream_ordering <= ?"
420 " ORDER BY stream_ordering %s LIMIT ?"
421 ) % (order,)
422 txn.execute(sql, (room_id, from_id, to_id, limit))
423
424 rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
489 # To handle tokens with a non-empty instance_map we fetch more
490 # results than necessary and then filter down
491 min_from_id = from_key.stream
492 max_to_id = to_key.get_max_stream_pos()
493
494 sql = """
495 SELECT event_id, instance_name, topological_ordering, stream_ordering
496 FROM events
497 WHERE
498 room_id = ?
499 AND not outlier
500 AND stream_ordering > ? AND stream_ordering <= ?
501 ORDER BY stream_ordering %s LIMIT ?
502 """ % (
503 order,
504 )
505 txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
506
507 rows = [
508 _EventDictReturn(event_id, None, stream_ordering)
509 for event_id, instance_name, topological_ordering, stream_ordering in txn
510 if _filter_results(
511 from_key,
512 to_key,
513 instance_name,
514 topological_ordering,
515 stream_ordering,
516 )
517 ][:limit]
425518 return rows
426519
427520 rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
430523 [r.event_id for r in rows], get_prev_content=True
431524 )
432525
433 self._set_before_and_after(ret, rows, topo_order=from_id is None)
526 self._set_before_and_after(ret, rows, topo_order=False)
434527
435528 if order.lower() == "desc":
436529 ret.reverse()
447540 async def get_membership_changes_for_user(
448541 self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
449542 ) -> List[EventBase]:
450 from_id = from_key.stream
451 to_id = to_key.stream
452
453543 if from_key == to_key:
454544 return []
455545
456 if from_id:
546 if from_key:
457547 has_changed = self._membership_stream_cache.has_entity_changed(
458 user_id, int(from_id)
548 user_id, int(from_key.stream)
459549 )
460550 if not has_changed:
461551 return []
462552
463553 def f(txn):
464 sql = (
465 "SELECT m.event_id, stream_ordering FROM events AS e,"
466 " room_memberships AS m"
467 " WHERE e.event_id = m.event_id"
468 " AND m.user_id = ?"
469 " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
470 " ORDER BY e.stream_ordering ASC"
471 )
472 txn.execute(sql, (user_id, from_id, to_id))
473
474 rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
554 # To handle tokens with a non-empty instance_map we fetch more
555 # results than necessary and then filter down
556 min_from_id = from_key.stream
557 max_to_id = to_key.get_max_stream_pos()
558
559 sql = """
560 SELECT m.event_id, instance_name, topological_ordering, stream_ordering
561 FROM events AS e, room_memberships AS m
562 WHERE e.event_id = m.event_id
563 AND m.user_id = ?
564 AND e.stream_ordering > ? AND e.stream_ordering <= ?
565 ORDER BY e.stream_ordering ASC
566 """
567 txn.execute(sql, (user_id, min_from_id, max_to_id,))
568
569 rows = [
570 _EventDictReturn(event_id, None, stream_ordering)
571 for event_id, instance_name, topological_ordering, stream_ordering in txn
572 if _filter_results(
573 from_key,
574 to_key,
575 instance_name,
576 topological_ordering,
577 stream_ordering,
578 )
579 ]
475580
476581 return rows
477582
545650
546651 async def get_room_event_before_stream_ordering(
547652 self, room_id: str, stream_ordering: int
548 ) -> Tuple[int, int, str]:
653 ) -> Optional[Tuple[int, int, str]]:
549654 """Gets details of the first event in a room at or before a stream ordering
550655
551656 Args:
587692 "_get_max_topological_txn", self._get_max_topological_txn, room_id
588693 )
589694 return "t%d-%d" % (topo, token)
590
591 async def get_stream_id_for_event(self, event_id: str) -> int:
592 """The stream ID for an event
593 Args:
594 event_id: The id of the event to look up a stream token for.
595 Raises:
596 StoreError if the event wasn't in the database.
597 Returns:
598 A stream ID.
599 """
600 return await self.db_pool.runInteraction(
601 "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
602 )
603695
604696 def get_stream_id_for_event_txn(
605697 self, txn: LoggingTransaction, event_id: str, allow_none=False,
9781070 else:
9791071 order = "ASC"
9801072
1073 # The bounds for the stream tokens are complicated by the fact
1074 # that we need to handle the instance_map part of the tokens. We do this
1075 # by fetching all events between the min stream token and the maximum
1076 # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
1077 # then filtering the results.
1078 if from_token.topological is not None:
1079 from_bound = (
1080 from_token.as_historical_tuple()
1081 ) # type: Tuple[Optional[int], int]
1082 elif direction == "b":
1083 from_bound = (
1084 None,
1085 from_token.get_max_stream_pos(),
1086 )
1087 else:
1088 from_bound = (
1089 None,
1090 from_token.stream,
1091 )
1092
1093 to_bound = None # type: Optional[Tuple[Optional[int], int]]
1094 if to_token:
1095 if to_token.topological is not None:
1096 to_bound = to_token.as_historical_tuple()
1097 elif direction == "b":
1098 to_bound = (
1099 None,
1100 to_token.stream,
1101 )
1102 else:
1103 to_bound = (
1104 None,
1105 to_token.get_max_stream_pos(),
1106 )
1107
9811108 bounds = generate_pagination_where_clause(
9821109 direction=direction,
9831110 column_names=("topological_ordering", "stream_ordering"),
984 from_token=from_token.as_tuple(),
985 to_token=to_token.as_tuple() if to_token else None,
1111 from_token=from_bound,
1112 to_token=to_bound,
9861113 engine=self.database_engine,
9871114 )
9881115
9921119 bounds += " AND " + filter_clause
9931120 args.extend(filter_args)
9941121
995 args.append(int(limit))
1122 # We fetch more events as we'll filter the result set
1123 args.append(int(limit) * 2)
9961124
9971125 select_keywords = "SELECT"
9981126 join_clause = ""
10141142 select_keywords += "DISTINCT"
10151143
10161144 sql = """
1017 %(select_keywords)s event_id, topological_ordering, stream_ordering
1145 %(select_keywords)s
1146 event_id, instance_name,
1147 topological_ordering, stream_ordering
10181148 FROM events
10191149 %(join_clause)s
10201150 WHERE outlier = ? AND room_id = ? AND %(bounds)s
10291159
10301160 txn.execute(sql, args)
10311161
1032 rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
1162 # Filter the result set.
1163 rows = [
1164 _EventDictReturn(event_id, topological_ordering, stream_ordering)
1165 for event_id, instance_name, topological_ordering, stream_ordering in txn
1166 if _filter_results(
1167 lower_token=to_token if direction == "b" else from_token,
1168 upper_token=from_token if direction == "b" else to_token,
1169 instance_name=instance_name,
1170 topological_ordering=topological_ordering,
1171 stream_ordering=stream_ordering,
1172 )
1173 ][:limit]
10331174
10341175 if rows:
10351176 topo = rows[-1].topological_ordering
10941235
10951236 return (events, token)
10961237
1238 @cached()
1239 async def get_id_for_instance(self, instance_name: str) -> int:
1240 """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
1241 """
1242
1243 def _get_id_for_instance_txn(txn):
1244 instance_id = self.db_pool.simple_select_one_onecol_txn(
1245 txn,
1246 table="instance_map",
1247 keyvalues={"instance_name": instance_name},
1248 retcol="instance_id",
1249 allow_none=True,
1250 )
1251 if instance_id is not None:
1252 return instance_id
1253
1254 # If we don't have an entry upsert one.
1255 #
1256 # We could do this before the first check, and rely on the cache for
1257 # efficiency, but each UPSERT causes the next ID to increment which
1258 # can quickly bloat the size of the generated IDs for new instances.
1259 self.db_pool.simple_upsert_txn(
1260 txn,
1261 table="instance_map",
1262 keyvalues={"instance_name": instance_name},
1263 values={},
1264 )
1265
1266 return self.db_pool.simple_select_one_onecol_txn(
1267 txn,
1268 table="instance_map",
1269 keyvalues={"instance_name": instance_name},
1270 retcol="instance_id",
1271 )
1272
1273 return await self.db_pool.runInteraction(
1274 "get_id_for_instance", _get_id_for_instance_txn
1275 )
1276
1277 @cached()
1278 async def get_name_from_instance_id(self, instance_id: int) -> str:
1279 """Get the instance name from an ID previously returned by
1280 `get_id_for_instance`.
1281 """
1282
1283 return await self.db_pool.simple_select_one_onecol(
1284 table="instance_map",
1285 keyvalues={"instance_id": instance_id},
1286 retcol="instance_name",
1287 desc="get_name_from_instance_id",
1288 )
1289
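A small hedged sketch of how the two cached lookups above compose; the store variable and instance name are illustrative.

    async def roundtrip_instance_name(store, name: str) -> str:
        # Both calls are @cached, so repeated token (de)serialisation is cheap.
        instance_id = await store.get_id_for_instance(name)        # e.g. 2
        return await store.get_name_from_instance_id(instance_id)  # back to `name`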
10971290
10981291 class StreamStore(StreamWorkerStore):
10991292 def get_room_max_stream_ordering(self) -> int:
1818
1919 from canonicaljson import encode_canonical_json
2020
21 from synapse.metrics.background_process_metrics import run_as_background_process
21 from synapse.metrics.background_process_metrics import wrap_as_background_process
2222 from synapse.storage._base import SQLBaseStore, db_to_json
2323 from synapse.storage.database import DatabasePool, LoggingTransaction
2424 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
4242 SENTINEL = object()
4343
4444
45 class TransactionStore(SQLBaseStore):
45 class TransactionWorkerStore(SQLBaseStore):
46 def __init__(self, database: DatabasePool, db_conn, hs):
47 super().__init__(database, db_conn, hs)
48
49 if hs.config.run_background_tasks:
50 self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
51
52 @wrap_as_background_process("cleanup_transactions")
53 async def _cleanup_transactions(self) -> None:
54 now = self._clock.time_msec()
55 month_ago = now - 30 * 24 * 60 * 60 * 1000
56
57 def _cleanup_transactions_txn(txn):
58 txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
59
60 await self.db_pool.runInteraction(
61 "_cleanup_transactions", _cleanup_transactions_txn
62 )
63
64
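The cleanup job above follows a general pattern for worker-safe background tasks; here is a minimal sketch of that pattern under the same imports as this module (SQLBaseStore, wrap_as_background_process), with a hypothetical table and interval.

    class ExampleWorkerStore(SQLBaseStore):
        def __init__(self, database, db_conn, hs):
            super().__init__(database, db_conn, hs)
            # Only the instance configured to run background tasks schedules the job.
            if hs.config.run_background_tasks:
                self._clock.looping_call(self._prune_old_rows, 60 * 60 * 1000)

        @wrap_as_background_process("prune_old_rows")
        async def _prune_old_rows(self) -> None:
            cutoff = self._clock.time_msec() - 24 * 60 * 60 * 1000

            def _prune_txn(txn):
                # "example_rows" is a made-up table, purely for illustration.
                txn.execute("DELETE FROM example_rows WHERE ts < ?", (cutoff,))

            await self.db_pool.runInteraction("prune_old_rows", _prune_txn)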
65 class TransactionStore(TransactionWorkerStore):
4666 """A collection of queries for handling PDUs.
4767 """
4868
4969 def __init__(self, database: DatabasePool, db_conn, hs):
5070 super().__init__(database, db_conn, hs)
51
52 self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
5371
5472 self._destination_retry_cache = ExpiringCache(
5573 cache_name="get_destination_retry_timings",
189207 """
190208
191209 self._destination_retry_cache.pop(destination, None)
192 return await self.db_pool.runInteraction(
193 "set_destination_retry_timings",
194 self._set_destination_retry_timings,
195 destination,
196 failure_ts,
197 retry_last_ts,
198 retry_interval,
199 )
200
201 def _set_destination_retry_timings(
210 if self.database_engine.can_native_upsert:
211 return await self.db_pool.runInteraction(
212 "set_destination_retry_timings",
213 self._set_destination_retry_timings_native,
214 destination,
215 failure_ts,
216 retry_last_ts,
217 retry_interval,
218 db_autocommit=True, # Safe as its a single upsert
219 )
220 else:
221 return await self.db_pool.runInteraction(
222 "set_destination_retry_timings",
223 self._set_destination_retry_timings_emulated,
224 destination,
225 failure_ts,
226 retry_last_ts,
227 retry_interval,
228 )
229
230 def _set_destination_retry_timings_native(
202231 self, txn, destination, failure_ts, retry_last_ts, retry_interval
203232 ):
204
205 if self.database_engine.can_native_upsert:
206 # Upsert retry time interval if retry_interval is zero (i.e. we're
207 # resetting it) or greater than the existing retry interval.
208
209 sql = """
210 INSERT INTO destinations (
211 destination, failure_ts, retry_last_ts, retry_interval
212 )
213 VALUES (?, ?, ?, ?)
214 ON CONFLICT (destination) DO UPDATE SET
215 failure_ts = EXCLUDED.failure_ts,
216 retry_last_ts = EXCLUDED.retry_last_ts,
217 retry_interval = EXCLUDED.retry_interval
218 WHERE
219 EXCLUDED.retry_interval = 0
220 OR destinations.retry_interval IS NULL
221 OR destinations.retry_interval < EXCLUDED.retry_interval
222 """
223
224 txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
225
226 return
227
233 assert self.database_engine.can_native_upsert
234
235 # Upsert retry time interval if retry_interval is zero (i.e. we're
236 # resetting it) or greater than the existing retry interval.
237 #
238 # WARNING: This is executed in autocommit, so we shouldn't add any more
239 # SQL calls in here (without being very careful).
240 sql = """
241 INSERT INTO destinations (
242 destination, failure_ts, retry_last_ts, retry_interval
243 )
244 VALUES (?, ?, ?, ?)
245 ON CONFLICT (destination) DO UPDATE SET
246 failure_ts = EXCLUDED.failure_ts,
247 retry_last_ts = EXCLUDED.retry_last_ts,
248 retry_interval = EXCLUDED.retry_interval
249 WHERE
250 EXCLUDED.retry_interval = 0
251 OR destinations.retry_interval IS NULL
252 OR destinations.retry_interval < EXCLUDED.retry_interval
253 """
254
255 txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
256
257 def _set_destination_retry_timings_emulated(
258 self, txn, destination, failure_ts, retry_last_ts, retry_interval
259 ):
228260 self.database_engine.lock_table(txn, "destinations")
229261
230262 # We need to be careful here as the data may have changed from under us
265297 },
266298 )
267299
268 def _start_cleanup_transactions(self):
269 return run_as_background_process(
270 "cleanup_transactions", self._cleanup_transactions
271 )
272
273 async def _cleanup_transactions(self) -> None:
274 now = self._clock.time_msec()
275 month_ago = now - 30 * 24 * 60 * 60 * 1000
276
277 def _cleanup_transactions_txn(txn):
278 txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
279
280 await self.db_pool.runInteraction(
281 "_cleanup_transactions", _cleanup_transactions_txn
282 )
283
284300 async def store_destination_rooms_entries(
285301 self, destinations: Iterable[str], room_id: str, stream_ordering: int,
286302 ) -> None:
287287 )
288288 return [(row["user_agent"], row["ip"]) for row in rows]
289289
290
291 class UIAuthStore(UIAuthWorkerStore):
292290 async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
293291 """
294292 Remove sessions which were last used earlier than the expiration time.
338336 iterable=session_ids,
339337 keyvalues={},
340338 )
339
340
341 class UIAuthStore(UIAuthWorkerStore):
342 pass
479479 user_id_tuples: iterable of 2-tuple of user IDs.
480480 """
481481
482 def _add_users_who_share_room_txn(txn):
483 self.db_pool.simple_upsert_many_txn(
484 txn,
485 table="users_who_share_private_rooms",
486 key_names=["user_id", "other_user_id", "room_id"],
487 key_values=[
488 (user_id, other_user_id, room_id)
489 for user_id, other_user_id in user_id_tuples
490 ],
491 value_names=(),
492 value_values=None,
493 )
494
495 await self.db_pool.runInteraction(
496 "add_users_who_share_room", _add_users_who_share_room_txn
482 await self.db_pool.simple_upsert_many(
483 table="users_who_share_private_rooms",
484 key_names=["user_id", "other_user_id", "room_id"],
485 key_values=[
486 (user_id, other_user_id, room_id)
487 for user_id, other_user_id in user_id_tuples
488 ],
489 value_names=(),
490 value_values=None,
491 desc="add_users_who_share_room",
497492 )
498493
499494 async def add_users_in_public_rooms(
507502 user_ids
508503 """
509504
510 def _add_users_in_public_rooms_txn(txn):
511
512 self.db_pool.simple_upsert_many_txn(
513 txn,
514 table="users_in_public_rooms",
515 key_names=["user_id", "room_id"],
516 key_values=[(user_id, room_id) for user_id in user_ids],
517 value_names=(),
518 value_values=None,
519 )
520
521 await self.db_pool.runInteraction(
522 "add_users_in_public_rooms", _add_users_in_public_rooms_txn
505 await self.db_pool.simple_upsert_many(
506 table="users_in_public_rooms",
507 key_names=["user_id", "room_id"],
508 key_values=[(user_id, room_id) for user_id in user_ids],
509 value_names=(),
510 value_values=None,
511 desc="add_users_in_public_rooms",
523512 )
524513
525514 async def delete_all_from_user_dir(self) -> None:
9595
9696 Returns:
9797 defer.Deferred: a deferred which will resolve once the events are
98 persisted. Runs its callbacks *without* a logcontext.
98 persisted. Runs its callbacks *without* a logcontext. The result
99 is the same as that returned by the callback passed to
100 `handle_queue`.
99101 """
100102 queue = self._event_persist_queues.setdefault(room_id, deque())
101103 if queue:
198200 self,
199201 events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
200202 backfilled: bool = False,
201 ) -> RoomStreamToken:
203 ) -> Tuple[List[EventBase], RoomStreamToken]:
202204 """
203205 Write events to the database
204206 Args:
208210 which might update the current state etc.
209211
210212 Returns:
211 the stream ordering of the latest persisted event
213 List of events persisted, and the current room stream position.
214 The list of events persisted may not be the same as those passed in
215 if they were deduplicated due to an event already existing that
216 matched the transaction ID; the existing event is returned in such
217 a case.
212218 """
213219 partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
214220 for event, ctx in events_and_contexts:
224230 for room_id in partitioned:
225231 self._maybe_start_persisting(room_id)
226232
227 await make_deferred_yieldable(
233 # Each deferred returns a map from event ID to existing event ID if the
234 # event was deduplicated. (The dict may also include other entries if
235 # the event was persisted in a batch with other events).
236 #
237 # Since we use `defer.gatherResults` we need to merge the returned list
238 # of dicts into one.
239 ret_vals = await make_deferred_yieldable(
228240 defer.gatherResults(deferreds, consumeErrors=True)
229241 )
230
231 return self.main_store.get_room_max_token()
242 replaced_events = {}
243 for d in ret_vals:
244 replaced_events.update(d)
245
246 events = []
247 for event, _ in events_and_contexts:
248 existing_event_id = replaced_events.get(event.event_id)
249 if existing_event_id:
250 events.append(await self.main_store.get_event(existing_event_id))
251 else:
252 events.append(event)
253
254 return (
255 events,
256 self.main_store.get_room_max_token(),
257 )
232258
233259 async def persist_event(
234260 self, event: EventBase, context: EventContext, backfilled: bool = False
235 ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
261 ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
236262 """
237263 Returns:
238 The stream ordering of `event`, and the stream ordering of the
239 latest persisted event
264 The event, stream ordering of `event`, and the stream ordering of the
265 latest persisted event. The returned event may not match the given
266 event if it was deduplicated due to an existing event matching the
267 transaction ID.
240268 """
241269 deferred = self._event_persist_queue.add_to_queue(
242270 event.room_id, [(event, context)], backfilled=backfilled
244272
245273 self._maybe_start_persisting(event.room_id)
246274
247 await make_deferred_yieldable(deferred)
275 # The deferred returns a map from event ID to existing event ID if the
276 # event was deduplicated. (The dict may also include other entries if
277 # the event was persisted in a batch with other events.)
278 replaced_events = await make_deferred_yieldable(deferred)
279 replaced_event = replaced_events.get(event.event_id)
280 if replaced_event:
281 event = await self.main_store.get_event(replaced_event)
248282
249283 event_stream_id = event.internal_metadata.stream_ordering
284 # stream ordering should have been assigned by now
285 assert event_stream_id
250286
251287 pos = PersistedEventPosition(self._instance_name, event_stream_id)
252 return pos, self.main_store.get_room_max_token()
288 return event, pos, self.main_store.get_room_max_token()
253289
254290 def _maybe_start_persisting(self, room_id: str):
291 """Pokes the `_event_persist_queue` to start handling new items in the
292 queue, if not already in progress.
293
294 Causes the deferreds returned by `add_to_queue` to resolve with a
295 dictionary mapping each event ID we did not persist to the ID of the
296 event already persisted with the same transaction ID.
297 """
298
255299 async def persisting_queue(item):
256300 with Measure(self._clock, "persist_events"):
257 await self._persist_events(
301 return await self._persist_events(
258302 item.events_and_contexts, backfilled=item.backfilled
259303 )
260304
264308 self,
265309 events_and_contexts: List[Tuple[EventBase, EventContext]],
266310 backfilled: bool = False,
267 ):
311 ) -> Dict[str, str]:
268312 """Calculates the change to current state and forward extremities, and
269313 persists the given events and with those updates.
270 """
314
315 Returns:
316 A dictionary mapping each event ID we did not persist to the ID of
317 the event already persisted with the same transaction ID.
318 """
319 replaced_events = {} # type: Dict[str, str]
271320 if not events_and_contexts:
272 return
321 return replaced_events
322
323 # Check if any of the events have a transaction ID that has already been
324 # persisted, and if so we don't persist it again.
325 #
326 # We should have checked this a long time before we get here, but it's
327 # possible that different send event requests race in such a way that
328 they both pass the earlier checks. Checking here isn't racy as we can
329 # have only one `_persist_events` per room being called at a time.
330 replaced_events = await self.main_store.get_already_persisted_events(
331 (event for event, _ in events_and_contexts)
332 )
333
334 if replaced_events:
335 events_and_contexts = [
336 (e, ctx)
337 for e, ctx in events_and_contexts
338 if e.event_id not in replaced_events
339 ]
340
341 if not events_and_contexts:
342 return replaced_events
273343
274344 chunks = [
275345 events_and_contexts[x : x + 100]
438508
439509 await self._handle_potentially_left_users(potentially_left_users)
440510
511 return replaced_events
512
441513 async def _calculate_new_extremities(
442514 self,
443515 room_id: str,
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15
1615 import imp
1716 import logging
1817 import os
2322 import attr
2423
2524 from synapse.config.homeserver import HomeServerConfig
25 from synapse.storage.database import LoggingDatabaseConnection
2626 from synapse.storage.engines import BaseDatabaseEngine
2727 from synapse.storage.engines.postgres import PostgresEngine
28 from synapse.storage.types import Connection, Cursor
28 from synapse.storage.types import Cursor
2929 from synapse.types import Collection
3030
3131 logger = logging.getLogger(__name__)
6666
6767
6868 def prepare_database(
69 db_conn: Connection,
69 db_conn: LoggingDatabaseConnection,
7070 database_engine: BaseDatabaseEngine,
7171 config: Optional[HomeServerConfig],
7272 databases: Collection[str] = ["main", "state"],
8888 """
8989
9090 try:
91 cur = db_conn.cursor()
91 cur = db_conn.cursor(txn_name="prepare_database")
9292
9393 # sqlite does not automatically start transactions for DDL / SELECT statements,
9494 # so we start one before running anything. This ensures that any upgrades
257257 executescript(cur, entry.absolute_path)
258258
259259 cur.execute(
260 database_engine.convert_param_style(
261 "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
262 ),
260 "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
263261 (max_current_ver, False),
264262 )
265263
485483
486484 # Mark as done.
487485 cur.execute(
488 database_engine.convert_param_style(
489 "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
490 ),
486 "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
491487 (v, relative_path),
492488 )
493489
494490 cur.execute("DELETE FROM schema_version")
495491 cur.execute(
496 database_engine.convert_param_style(
497 "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
498 ),
492 "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
499493 (v, True),
500494 )
501495
531525 schemas to be applied
532526 """
533527 cur.execute(
534 database_engine.convert_param_style(
535 "SELECT file FROM applied_module_schemas WHERE module_name = ?"
536 ),
537 (modname,),
528 "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
538529 )
539530 applied_deltas = {d for d, in cur}
540531 for (name, stream) in names_and_streams:
552543
553544 # Mark as done.
554545 cur.execute(
555 database_engine.convert_param_style(
556 "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
557 ),
546 "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
558547 (modname, name),
559548 )
560549
626615
627616 if current_version:
628617 txn.execute(
629 database_engine.convert_param_style(
630 "SELECT file FROM applied_schema_deltas WHERE version >= ?"
631 ),
618 "SELECT file FROM applied_schema_deltas WHERE version >= ?",
632619 (current_version,),
633620 )
634621 applied_deltas = [d for d, in txn]
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Iterable, Iterator, List, Tuple
14 from typing import Any, Iterable, Iterator, List, Optional, Tuple
1515
1616 from typing_extensions import Protocol
1717
6060
6161 def rollback(self, *args, **kwargs) -> None:
6262 ...
63
64 def __enter__(self) -> "Connection":
65 ...
66
67 def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
68 ...
5454 """
5555 # debug logging for https://github.com/matrix-org/synapse/issues/7968
5656 logger.info("initialising stream generator for %s(%s)", table, column)
57 cur = db_conn.cursor()
57 cur = db_conn.cursor(txn_name="_load_current_id")
5858 if step == 1:
5959 cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
6060 else:
269269 def _load_current_ids(
270270 self, db_conn, table: str, instance_column: str, id_column: str
271271 ):
272 cur = db_conn.cursor()
272 cur = db_conn.cursor(txn_name="_load_current_ids")
273273
274274 # Load the current positions of all writers for the stream.
275275 if self._writers:
283283 stream_name = ?
284284 AND instance_name != ALL(?)
285285 """
286 sql = self._db.engine.convert_param_style(sql)
287286 cur.execute(sql, (self._stream_name, self._writers))
288287
289288 sql = """
290289 SELECT instance_name, stream_id FROM stream_positions
291290 WHERE stream_name = ?
292291 """
293 sql = self._db.engine.convert_param_style(sql)
294
295292 cur.execute(sql, (self._stream_name,))
296293
297294 self._current_positions = {
340337 "instance": instance_column,
341338 "cmp": "<=" if self._positive else ">=",
342339 }
343 sql = self._db.engine.convert_param_style(sql)
344340 cur.execute(sql, (min_stream_id * self._return_factor,))
345341
346342 self._persisted_upto_position = min_stream_id
421417 self._unfinished_ids.discard(next_id)
422418 self._finished_ids.add(next_id)
423419
424 new_cur = None
420 new_cur = None # type: Optional[int]
425421
426422 if self._unfinished_ids:
427423 # If there are unfinished IDs then the new position will be the
526522 assert self._lock.locked()
527523
528524 heapq.heappush(self._known_persisted_positions, new_id)
525
526 # If we're a writer and we don't have any active writes we update our
527 # current position to the latest position seen. This allows the instance
528 # to report a recent position when asked, rather than a potentially old
529 # one (if this instance hasn't written anything for a while).
530 our_current_position = self._current_positions.get(self._instance_name)
531 if our_current_position and not self._unfinished_ids:
532 self._current_positions[self._instance_name] = max(
533 our_current_position, new_id
534 )
529535
530536 # We move the current min position up if the minimum current positions
531537 # of all instances is higher (since by definition all positions less
1616 import threading
1717 from typing import Callable, List, Optional
1818
19 from synapse.storage.database import LoggingDatabaseConnection
1920 from synapse.storage.engines import (
2021 BaseDatabaseEngine,
2122 IncorrectDatabaseSetup,
5253
5354 @abc.abstractmethod
5455 def check_consistency(
55 self, db_conn: Connection, table: str, id_column: str, positive: bool = True
56 self,
57 db_conn: LoggingDatabaseConnection,
58 table: str,
59 id_column: str,
60 positive: bool = True,
5661 ):
5762 """Should be called during start up to test that the current value of
5863 the sequence is greater than or equal to the maximum ID in the table.
8186 return [i for (i,) in txn]
8287
8388 def check_consistency(
84 self, db_conn: Connection, table: str, id_column: str, positive: bool = True
89 self,
90 db_conn: LoggingDatabaseConnection,
91 table: str,
92 id_column: str,
93 positive: bool = True,
8594 ):
86 txn = db_conn.cursor()
95 txn = db_conn.cursor(txn_name="sequence.check_consistency")
8796
8897 # First we get the current max ID from the table.
8998 table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
116125 if max_stream_id > last_value:
117126 logger.warning(
118127 "Postgres sequence %s is behind table %s: %d < %d",
128 self._sequence_name,
129 table,
119130 last_value,
120131 max_stream_id,
121132 )
2121 TYPE_CHECKING,
2222 Any,
2323 Dict,
24 Iterable,
2425 Mapping,
2526 MutableMapping,
2627 Optional,
4243 if sys.version_info[:3] >= (3, 6, 0):
4344 from typing import Collection
4445 else:
45 from typing import Container, Iterable, Sized
46 from typing import Container, Sized
4647
4748 T_co = TypeVar("T_co", covariant=True)
4849
374375 return username.decode("ascii")
375376
376377
377 @attr.s(frozen=True, slots=True)
378 @attr.s(frozen=True, slots=True, cmp=False)
378379 class RoomStreamToken:
379380 """Tokens are positions between events. The token "s1" comes after event 1.
380381
396397 event it comes after. Historic tokens start with a "t" followed by the
397398 "topological_ordering" id of the event it comes after, followed by "-",
398399 followed by the "stream_ordering" id of the event it comes after.
400
401 There is also a third mode for live tokens where the token starts with "m",
402 which is sometimes used when using sharded event persisters. In this case
403 the events stream is considered to be a set of streams (one for each writer)
404 and the token encodes the vector clock of positions of each writer in their
405 respective streams.
406
407 The format of the token in such case is an initial integer min position,
408 followed by the mapping of instance ID to position separated by '.' and '~':
409
410 m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ...
411
412 The `min_pos` corresponds to the minimum position all writers have persisted
413 up to, and then only writers that are ahead of that position need to be
414 encoded. An example token is:
415
416 m56~2.58~3.59
417
418 This corresponds to a set of three (or more) writers where instances 2 and
419 3 (these are instance IDs that can be looked up in the DB to fetch the more
420 commonly used instance names) are at positions 58 and 59 respectively, and
421 all other instances are at position 56.
422
423 Note: The `RoomStreamToken` cannot have both a topological part and an
424 instance map.
399425 """
400426
401427 topological = attr.ib(
403429 validator=attr.validators.optional(attr.validators.instance_of(int)),
404430 )
405431 stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
432
433 instance_map = attr.ib(
434 type=Dict[str, int],
435 factory=dict,
436 validator=attr.validators.deep_mapping(
437 key_validator=attr.validators.instance_of(str),
438 value_validator=attr.validators.instance_of(int),
439 mapping_validator=attr.validators.instance_of(dict),
440 ),
441 )
442
443 def __attrs_post_init__(self):
444 """Validates that both `topological` and `instance_map` aren't set.
445 """
446
447 if self.instance_map and self.topological:
448 raise ValueError(
449 "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
450 )
406451
407452 @classmethod
408453 async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
412457 if string[0] == "t":
413458 parts = string[1:].split("-", 1)
414459 return cls(topological=int(parts[0]), stream=int(parts[1]))
460 if string[0] == "m":
461 parts = string[1:].split("~")
462 stream = int(parts[0])
463
464 instance_map = {}
465 for part in parts[1:]:
466 key, value = part.split(".")
467 instance_id = int(key)
468 pos = int(value)
469
470 instance_name = await store.get_name_from_instance_id(instance_id)
471 instance_map[instance_name] = pos
472
473 return cls(topological=None, stream=stream, instance_map=instance_map,)
415474 except Exception:
416475 pass
417476 raise SynapseError(400, "Invalid token %r" % (string,))
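A short hedged example of the "m" encoding round-trip; the token string matches the example in the class docstring, and the store object is assumed to provide the instance-ID mappings.

    async def demo_m_token(store) -> None:
        # "m56~2.58~3.59": every writer is at least at position 56; the writers
        # with instance IDs 2 and 3 are ahead, at 58 and 59 respectively.
        token = await RoomStreamToken.parse(store, "m56~2.58~3.59")
        assert token.stream == 56
        assert token.get_max_stream_pos() == 59
        # to_string() maps the instance names back to their integer IDs.
        assert (await token.to_string(store)).startswith("m56~")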
435494
436495 max_stream = max(self.stream, other.stream)
437496
438 return RoomStreamToken(None, max_stream)
439
440 def as_tuple(self) -> Tuple[Optional[int], int]:
497 instance_map = {
498 instance: max(
499 self.instance_map.get(instance, self.stream),
500 other.instance_map.get(instance, other.stream),
501 )
502 for instance in set(self.instance_map).union(other.instance_map)
503 }
504
505 return RoomStreamToken(None, max_stream, instance_map)
506
507 def as_historical_tuple(self) -> Tuple[int, int]:
508 """Returns a tuple of `(topological, stream)` for historical tokens.
509
510 Raises if not an historical token (i.e. doesn't have a topological part).
511 """
512 if self.topological is None:
513 raise Exception(
514 "Cannot call `RoomStreamToken.as_historical_tuple` on live token"
515 )
516
441517 return (self.topological, self.stream)
518
519 def get_stream_pos_for_instance(self, instance_name: str) -> int:
520 """Get the stream position that the given writer was at at this token.
521
522 This only makes sense for "live" tokens that may have a vector clock
523 component, and so asserts that this is a "live" token.
524 """
525 assert self.topological is None
526
527 # If we don't have an entry for the instance we can assume that it was
528 # at `self.stream`.
529 return self.instance_map.get(instance_name, self.stream)
530
531 def get_max_stream_pos(self) -> int:
532 """Get the maximum stream position referenced in this token.
533
534 The corresponding "min" position is, by definition just `self.stream`.
535
536 This is used to handle tokens that have non-empty `instance_map`, and so
537 reference stream positions after the `self.stream` position.
538 """
539 return max(self.instance_map.values(), default=self.stream)
442540
443541 async def to_string(self, store: "DataStore") -> str:
444542 if self.topological is not None:
445543 return "t%d-%d" % (self.topological, self.stream)
544 elif self.instance_map:
545 entries = []
546 for name, pos in self.instance_map.items():
547 instance_id = await store.get_id_for_instance(name)
548 entries.append("{}.{}".format(instance_id, pos))
549
550 encoded_map = "~".join(entries)
551 return "m{}~{}".format(self.stream, encoded_map)
446552 else:
447553 return "s%d" % (self.stream,)
448554
534640 stream = attr.ib(type=int)
535641
536642 def persisted_after(self, token: RoomStreamToken) -> bool:
537 return token.stream < self.stream
643 return token.get_stream_pos_for_instance(self.instance_name) < self.stream
538644
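A brief hedged illustration of why the comparison now consults the per-writer position rather than the bare minimum (instance name and positions invented):

    pos = PersistedEventPosition("event_persister2", 58)
    live_token = RoomStreamToken(None, 56, {"event_persister2": 58})

    # Comparing against the minimum (56) alone would wrongly report the event as
    # newer than the token; the vector-clock entry shows it is already included.
    assert pos.persisted_after(live_token) is False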
539645 def to_room_stream_token(self) -> RoomStreamToken:
540646 """Converts the position to a room stream token such that events
1717 import re
1818
1919 import attr
20 from frozendict import frozendict
2021
2122 from twisted.internet import defer, task
2223
3031 raise ValueError("Invalid JSON value: '%s'" % val)
3132
3233
33 # Create a custom encoder to reduce the whitespace produced by JSON encoding and
34 # ensure that valid JSON is produced.
35 json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
34 def _handle_frozendict(obj):
35 """Helper for json_encoder. Makes frozendicts serializable by returning
36 the underlying dict
37 """
38 if type(obj) is frozendict:
39 # fishing the protected dict out of the object is a bit nasty,
40 # but we don't really want the overhead of copying the dict.
41 return obj._dict
42 raise TypeError(
43 "Object of type %s is not JSON serializable" % obj.__class__.__name__
44 )
45
46
47 # A custom JSON encoder which:
48 # * handles frozendicts
49 # * produces valid JSON (no NaNs etc)
50 # * reduces redundant whitespace
51 json_encoder = json.JSONEncoder(
52 allow_nan=False, separators=(",", ":"), default=_handle_frozendict
53 )
3654
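A quick hedged usage example of the updated encoder; it relies on the `frozendict` import at the top of this module, and the dict contents are arbitrary.

    frozen = frozendict({"event_id": "$abc", "depth": 1})
    compact = json_encoder.encode(frozen)  # '{"event_id":"$abc","depth":1}'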
3755 # Create a custom decoder to reject Python extensions to JSON.
3856 json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
1515
1616 import logging
1717 from sys import intern
18 from typing import Callable, Dict, Optional
18 from typing import Callable, Dict, Optional, Sized
1919
2020 import attr
2121 from prometheus_client.core import Gauge
9191 def register_cache(
9292 cache_type: str,
9393 cache_name: str,
94 cache,
94 cache: Sized,
9595 collect_callback: Optional[Callable] = None,
9696 resizable: bool = True,
9797 resize_callback: Optional[Callable] = None,
9999 """Register a cache object for metric collection and resizing.
100100
101101 Args:
102 cache_type
102 cache_type: a string indicating the "type" of the cache. This is used
103 only for deduplication so isn't too important provided it's constant.
103104 cache_name: name of the cache
104 cache: cache itself
105 cache: cache itself, which must implement __len__(), and may optionally implement
106 a max_size property
105107 collect_callback: If given, a function which is called during metric
106108 collection to update additional metrics.
107 resizable: Whether this cache supports being resized.
109 resizable: Whether this cache supports being resized, in which case either
110 resize_callback must be provided, or the cache must support set_max_size().
108111 resize_callback: A function which can be called to resize the cache.
109112
110113 Returns:
0 # -*- coding: utf-8 -*-
1 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
3 # Copyright 2020 The Matrix.org Foundation C.I.C.
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import enum
18 import threading
19 from typing import (
20 Callable,
21 Generic,
22 Iterable,
23 MutableMapping,
24 Optional,
25 TypeVar,
26 Union,
27 cast,
28 )
29
30 from prometheus_client import Gauge
31
32 from twisted.internet import defer
33 from twisted.python import failure
34
35 from synapse.util.async_helpers import ObservableDeferred
36 from synapse.util.caches.lrucache import LruCache
37 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
38
39 cache_pending_metric = Gauge(
40 "synapse_util_caches_cache_pending",
41 "Number of lookups currently pending for this cache",
42 ["name"],
43 )
44
45 T = TypeVar("T")
46 KT = TypeVar("KT")
47 VT = TypeVar("VT")
48
49
50 class _Sentinel(enum.Enum):
51 # defining a sentinel in this way allows mypy to correctly handle the
52 # type of a dictionary lookup.
53 sentinel = object()
54
55
56 class DeferredCache(Generic[KT, VT]):
57 """Wraps an LruCache, adding support for Deferred results.
58
59 It expects that each entry added with set() will be a Deferred; likewise get()
60 will return a Deferred.
61 """
62
63 __slots__ = (
64 "cache",
65 "thread",
66 "_pending_deferred_cache",
67 )
68
69 def __init__(
70 self,
71 name: str,
72 max_entries: int = 1000,
73 keylen: int = 1,
74 tree: bool = False,
75 iterable: bool = False,
76 apply_cache_factor_from_config: bool = True,
77 ):
78 """
79 Args:
80 name: The name of the cache
81 max_entries: Maximum amount of entries that the cache will hold
82 keylen: The length of the tuple used as the cache key. Ignored unless
83 `tree` is True.
84 tree: Use a TreeCache instead of a dict as the underlying cache type
85 iterable: If True, count each item in the cached object as an entry,
86 rather than each cached object
87 apply_cache_factor_from_config: Whether cache factors specified in the
88 config file affect `max_entries`
89 """
90 cache_type = TreeCache if tree else dict
91
92 # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
93 self._pending_deferred_cache = (
94 cache_type()
95 ) # type: MutableMapping[KT, CacheEntry]
96
97 def metrics_cb():
98 cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
99
100 # cache is used for completed results and maps to the result itself, rather than
101 # a Deferred.
102 self.cache = LruCache(
103 max_size=max_entries,
104 keylen=keylen,
105 cache_name=name,
106 cache_type=cache_type,
107 size_callback=(lambda d: len(d)) if iterable else None,
108 metrics_collection_callback=metrics_cb,
109 apply_cache_factor_from_config=apply_cache_factor_from_config,
110 ) # type: LruCache[KT, VT]
111
112 self.thread = None # type: Optional[threading.Thread]
113
114 @property
115 def max_entries(self):
116 return self.cache.max_size
117
118 def check_thread(self):
119 expected_thread = self.thread
120 if expected_thread is None:
121 self.thread = threading.current_thread()
122 else:
123 if expected_thread is not threading.current_thread():
124 raise ValueError(
125 "Cache objects can only be accessed from the main thread"
126 )
127
128 def get(
129 self,
130 key: KT,
131 callback: Optional[Callable[[], None]] = None,
132 update_metrics: bool = True,
133 ) -> defer.Deferred:
134 """Looks the key up in the caches.
135
136 For symmetry with set(), this method does *not* follow the synapse logcontext
137 rules: the logcontext will not be cleared on return, and the Deferred will run
138 its callbacks in the sentinel context. In other words: wrap the result with
139 make_deferred_yieldable() before `await`ing it.
140
141 Args:
142 key:
143 callback: Gets called when the entry in the cache is invalidated
144 update_metrics (bool): whether to update the cache hit rate metrics
145
146 Returns:
147 A Deferred which completes with the result. Note that this may later fail
148 if there is an ongoing set() operation which later completes with a failure.
149
150 Raises:
151 KeyError if the key is not found in the cache
152 """
153 callbacks = [callback] if callback else []
154 val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
155 if val is not _Sentinel.sentinel:
156 val.callbacks.update(callbacks)
157 if update_metrics:
158 m = self.cache.metrics
159 assert m # we always have a name, so should always have metrics
160 m.inc_hits()
161 return val.deferred.observe()
162
163 val2 = self.cache.get(
164 key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
165 )
166 if val2 is _Sentinel.sentinel:
167 raise KeyError()
168 else:
169 return defer.succeed(val2)
170
171 def get_immediate(
172 self, key: KT, default: T, update_metrics: bool = True
173 ) -> Union[VT, T]:
174 """If we have a *completed* cached value, return it."""
175 return self.cache.get(key, default, update_metrics=update_metrics)
176
177 def set(
178 self,
179 key: KT,
180 value: defer.Deferred,
181 callback: Optional[Callable[[], None]] = None,
182 ) -> defer.Deferred:
183 """Adds a new entry to the cache (or updates an existing one).
184
185 The given `value` *must* be a Deferred.
186
187 First any existing entry for the same key is invalidated. Then a new entry
188 is added to the cache for the given key.
189
190 Until the `value` completes, calls to `get()` for the key will also result in an
191 incomplete Deferred, which will ultimately complete with the same result as
192 `value`.
193
194 If `value` completes successfully, subsequent calls to `get()` will then return
195 a completed deferred with the same result. If it *fails*, the cache is
196 invalidated and subsequent calls to `get()` will raise a KeyError.
197
198 If another call to `set()` happens before `value` completes, then (a) any
199 invalidation callbacks registered in the interim will be called, (b) any
200 `get()`s in the interim will continue to complete with the result from the
201 *original* `value`, (c) any future calls to `get()` will complete with the
202 result from the *new* `value`.
203
204 It is expected that `value` does *not* follow the synapse logcontext rules - ie,
205 if it is incomplete, it runs its callbacks in the sentinel context.
206
207 Args:
208 key: Key to be set
209 value: a deferred which will complete with a result to add to the cache
210 callback: An optional callback to be called when the entry is invalidated
211 """
212 if not isinstance(value, defer.Deferred):
213 raise TypeError("not a Deferred")
214
215 callbacks = [callback] if callback else []
216 self.check_thread()
217
218 existing_entry = self._pending_deferred_cache.pop(key, None)
219 if existing_entry:
220 existing_entry.invalidate()
221
222 # XXX: why don't we invalidate the entry in `self.cache` yet?
223
224 # we can save a whole load of effort if the deferred is ready.
225 if value.called:
226 result = value.result
227 if not isinstance(result, failure.Failure):
228 self.cache.set(key, result, callbacks)
229 return value
230
231 # otherwise, we'll add an entry to the _pending_deferred_cache for now,
232 # and add callbacks to add it to the cache properly later.
233
234 observable = ObservableDeferred(value, consumeErrors=True)
235 observer = observable.observe()
236 entry = CacheEntry(deferred=observable, callbacks=callbacks)
237
238 self._pending_deferred_cache[key] = entry
239
240 def compare_and_pop():
241 """Check if our entry is still the one in _pending_deferred_cache, and
242 if so, pop it.
243
244 Returns true if the entries matched.
245 """
246 existing_entry = self._pending_deferred_cache.pop(key, None)
247 if existing_entry is entry:
248 return True
249
250 # oops, the _pending_deferred_cache has been updated since
251 # we started our query, so we are out of date.
252 #
253 # Better put back whatever we took out. (We do it this way
254 # round, rather than peeking into the _pending_deferred_cache
255 # and then removing on a match, to make the common case faster)
256 if existing_entry is not None:
257 self._pending_deferred_cache[key] = existing_entry
258
259 return False
260
261 def cb(result):
262 if compare_and_pop():
263 self.cache.set(key, result, entry.callbacks)
264 else:
265 # we're not going to put this entry into the cache, so need
266 # to make sure that the invalidation callbacks are called.
267 # That was probably done when _pending_deferred_cache was
268 # updated, but it's possible that `set` was called without
269 # `invalidate` being previously called, in which case it may
270 # not have been. Either way, let's double-check now.
271 entry.invalidate()
272
273 def eb(_fail):
274 compare_and_pop()
275 entry.invalidate()
276
277 # once the deferred completes, we can move the entry from the
278 # _pending_deferred_cache to the real cache.
279 #
280 observer.addCallbacks(cb, eb)
281
282 # we return a new Deferred which will be called before any subsequent observers.
283 return observable.observe()
284
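A minimal standalone sketch of the set()/get() contract documented above (an editor's illustration, not part of the diff). It assumes synapse 1.22 or later is importable; the cache name, key and value are invented for the example.

    from twisted.internet import defer

    from synapse.util.caches.deferred_cache import DeferredCache

    cache = DeferredCache("example")  # type: DeferredCache[str, int]

    lookup = defer.Deferred()            # the real lookup, still in flight
    observer = cache.set("key", lookup)  # a fresh Deferred tied to `lookup`
    pending = cache.get("key")           # also incomplete until `lookup` fires
    assert not observer.called and not pending.called

    lookup.callback(42)                  # the lookup completes...
    assert observer.result == 42         # ...and both deferreds see the result
    assert pending.result == 42

Per the docstring, these deferreds run their callbacks in the sentinel logcontext, so callers typically wrap them with make_deferred_yieldable(), as the descriptor code later in this diff does.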
285 def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
286 callbacks = [callback] if callback else []
287 self.cache.set(key, value, callbacks=callbacks)
288
289 def invalidate(self, key):
290 self.check_thread()
291 self.cache.pop(key, None)
292
293 # if we have a pending lookup for this key, remove it from the
294 # _pending_deferred_cache, which will (a) stop it being returned
295 # for future queries and (b) stop it being persisted as a proper entry
296 # in self.cache.
297 entry = self._pending_deferred_cache.pop(key, None)
298
299 # run the invalidation callbacks now, rather than waiting for the
300 # deferred to resolve.
301 if entry:
302 entry.invalidate()
303
304 def invalidate_many(self, key: KT):
305 self.check_thread()
306 if not isinstance(key, tuple):
307 raise TypeError("The cache key must be a tuple not %r" % (type(key),))
308 key = cast(KT, key)
309 self.cache.del_multi(key)
310
311 # if we have a pending lookup for this key, remove it from the
312 # _pending_deferred_cache, as above
313 entry_dict = self._pending_deferred_cache.pop(key, None)
314 if entry_dict is not None:
315 for entry in iterate_tree_cache_entry(entry_dict):
316 entry.invalidate()
317
318 def invalidate_all(self):
319 self.check_thread()
320 self.cache.clear()
321 for entry in self._pending_deferred_cache.values():
322 entry.invalidate()
323 self._pending_deferred_cache.clear()
324
325
326 class CacheEntry:
327 __slots__ = ["deferred", "callbacks", "invalidated"]
328
329 def __init__(
330 self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
331 ):
332 self.deferred = deferred
333 self.callbacks = set(callbacks)
334 self.invalidated = False
335
336 def invalidate(self):
337 if not self.invalidated:
338 self.invalidated = True
339 for callback in self.callbacks:
340 callback()
341 self.callbacks.clear()
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15
1615 import functools
1716 import inspect
1817 import logging
19 import threading
2018 from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
2119 from weakref import WeakValueDictionary
2220
23 from prometheus_client import Gauge
24
2521 from twisted.internet import defer
2622
2723 from synapse.logging.context import make_deferred_yieldable, preserve_fn
2824 from synapse.util import unwrapFirstError
29 from synapse.util.async_helpers import ObservableDeferred
30 from synapse.util.caches.lrucache import LruCache
31 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
32
33 from . import register_cache
25 from synapse.util.caches.deferred_cache import DeferredCache
3426
3527 logger = logging.getLogger(__name__)
3628
5244 # Note: This function signature is actually fiddled with by the synapse mypy
5345 # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
5446 __call__ = None # type: F
55
56
57 cache_pending_metric = Gauge(
58 "synapse_util_caches_cache_pending",
59 "Number of lookups currently pending for this cache",
60 ["name"],
61 )
62
63 _CacheSentinel = object()
64
65
66 class CacheEntry:
67 __slots__ = ["deferred", "callbacks", "invalidated"]
68
69 def __init__(self, deferred, callbacks):
70 self.deferred = deferred
71 self.callbacks = set(callbacks)
72 self.invalidated = False
73
74 def invalidate(self):
75 if not self.invalidated:
76 self.invalidated = True
77 for callback in self.callbacks:
78 callback()
79 self.callbacks.clear()
80
81
82 class Cache:
83 __slots__ = (
84 "cache",
85 "name",
86 "keylen",
87 "thread",
88 "metrics",
89 "_pending_deferred_cache",
90 )
91
92 def __init__(
93 self,
94 name: str,
95 max_entries: int = 1000,
96 keylen: int = 1,
97 tree: bool = False,
98 iterable: bool = False,
99 apply_cache_factor_from_config: bool = True,
100 ):
101 """
102 Args:
103 name: The name of the cache
104 max_entries: Maximum amount of entries that the cache will hold
105 keylen: The length of the tuple used as the cache key
106 tree: Use a TreeCache instead of a dict as the underlying cache type
107 iterable: If True, count each item in the cached object as an entry,
108 rather than each cached object
109 apply_cache_factor_from_config: Whether cache factors specified in the
110 config file affect `max_entries`
111
112 Returns:
113 Cache
114 """
115 cache_type = TreeCache if tree else dict
116 self._pending_deferred_cache = cache_type()
117
118 self.cache = LruCache(
119 max_size=max_entries,
120 keylen=keylen,
121 cache_type=cache_type,
122 size_callback=(lambda d: len(d)) if iterable else None,
123 evicted_callback=self._on_evicted,
124 apply_cache_factor_from_config=apply_cache_factor_from_config,
125 )
126
127 self.name = name
128 self.keylen = keylen
129 self.thread = None # type: Optional[threading.Thread]
130 self.metrics = register_cache(
131 "cache",
132 name,
133 self.cache,
134 collect_callback=self._metrics_collection_callback,
135 )
136
137 @property
138 def max_entries(self):
139 return self.cache.max_size
140
141 def _on_evicted(self, evicted_count):
142 self.metrics.inc_evictions(evicted_count)
143
144 def _metrics_collection_callback(self):
145 cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
146
147 def check_thread(self):
148 expected_thread = self.thread
149 if expected_thread is None:
150 self.thread = threading.current_thread()
151 else:
152 if expected_thread is not threading.current_thread():
153 raise ValueError(
154 "Cache objects can only be accessed from the main thread"
155 )
156
157 def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
158 """Looks the key up in the caches.
159
160 Args:
161 key(tuple)
162 default: What is returned if key is not in the caches. If not
163 specified then function throws KeyError instead
164 callback(fn): Gets called when the entry in the cache is invalidated
165 update_metrics (bool): whether to update the cache hit rate metrics
166
167 Returns:
168 Either an ObservableDeferred or the raw result
169 """
170 callbacks = [callback] if callback else []
171 val = self._pending_deferred_cache.get(key, _CacheSentinel)
172 if val is not _CacheSentinel:
173 val.callbacks.update(callbacks)
174 if update_metrics:
175 self.metrics.inc_hits()
176 return val.deferred
177
178 val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
179 if val is not _CacheSentinel:
180 self.metrics.inc_hits()
181 return val
182
183 if update_metrics:
184 self.metrics.inc_misses()
185
186 if default is _CacheSentinel:
187 raise KeyError()
188 else:
189 return default
190
191 def set(self, key, value, callback=None):
192 if not isinstance(value, defer.Deferred):
193 raise TypeError("not a Deferred")
194
195 callbacks = [callback] if callback else []
196 self.check_thread()
197 observable = ObservableDeferred(value, consumeErrors=True)
198 observer = observable.observe()
199 entry = CacheEntry(deferred=observable, callbacks=callbacks)
200
201 existing_entry = self._pending_deferred_cache.pop(key, None)
202 if existing_entry:
203 existing_entry.invalidate()
204
205 self._pending_deferred_cache[key] = entry
206
207 def compare_and_pop():
208 """Check if our entry is still the one in _pending_deferred_cache, and
209 if so, pop it.
210
211 Returns true if the entries matched.
212 """
213 existing_entry = self._pending_deferred_cache.pop(key, None)
214 if existing_entry is entry:
215 return True
216
217 # oops, the _pending_deferred_cache has been updated since
218 # we started our query, so we are out of date.
219 #
220 # Better put back whatever we took out. (We do it this way
221 # round, rather than peeking into the _pending_deferred_cache
222 # and then removing on a match, to make the common case faster)
223 if existing_entry is not None:
224 self._pending_deferred_cache[key] = existing_entry
225
226 return False
227
228 def cb(result):
229 if compare_and_pop():
230 self.cache.set(key, result, entry.callbacks)
231 else:
232 # we're not going to put this entry into the cache, so need
233 # to make sure that the invalidation callbacks are called.
234 # That was probably done when _pending_deferred_cache was
235 # updated, but it's possible that `set` was called without
236 # `invalidate` being previously called, in which case it may
237 # not have been. Either way, let's double-check now.
238 entry.invalidate()
239
240 def eb(_fail):
241 compare_and_pop()
242 entry.invalidate()
243
244 # once the deferred completes, we can move the entry from the
245 # _pending_deferred_cache to the real cache.
246 #
247 observer.addCallbacks(cb, eb)
248 return observable
249
250 def prefill(self, key, value, callback=None):
251 callbacks = [callback] if callback else []
252 self.cache.set(key, value, callbacks=callbacks)
253
254 def invalidate(self, key):
255 self.check_thread()
256 self.cache.pop(key, None)
257
258 # if we have a pending lookup for this key, remove it from the
259 # _pending_deferred_cache, which will (a) stop it being returned
260 # for future queries and (b) stop it being persisted as a proper entry
261 # in self.cache.
262 entry = self._pending_deferred_cache.pop(key, None)
263
264 # run the invalidation callbacks now, rather than waiting for the
265 # deferred to resolve.
266 if entry:
267 entry.invalidate()
268
269 def invalidate_many(self, key):
270 self.check_thread()
271 if not isinstance(key, tuple):
272 raise TypeError("The cache key must be a tuple not %r" % (type(key),))
273 self.cache.del_multi(key)
274
275 # if we have a pending lookup for this key, remove it from the
276 # _pending_deferred_cache, as above
277 entry_dict = self._pending_deferred_cache.pop(key, None)
278 if entry_dict is not None:
279 for entry in iterate_tree_cache_entry(entry_dict):
280 entry.invalidate()
281
282 def invalidate_all(self):
283 self.check_thread()
284 self.cache.clear()
285 for entry in self._pending_deferred_cache.values():
286 entry.invalidate()
287 self._pending_deferred_cache.clear()
28847
28948
29049 class _CacheDescriptorBase:
389148 self.iterable = iterable
390149
391150 def __get__(self, obj, owner):
392 cache = Cache(
151 cache = DeferredCache(
393152 name=self.orig.__name__,
394153 max_entries=self.max_entries,
395154 keylen=self.num_args,
396155 tree=self.tree,
397156 iterable=self.iterable,
398 )
157 ) # type: DeferredCache[CacheKey, Any]
399158
400159 def get_cache_key_gen(args, kwargs):
401160 """Given some args/kwargs return a generator that resolves into
441200
442201 cache_key = get_cache_key(args, kwargs)
443202
444 # Add our own `cache_context` to argument list if the wrapped function
445 # has asked for one
446 if self.add_cache_context:
447 kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
448
449203 try:
450 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
451
452 if isinstance(cached_result_d, ObservableDeferred):
453 observer = cached_result_d.observe()
454 else:
455 observer = defer.succeed(cached_result_d)
456
204 ret = cache.get(cache_key, callback=invalidate_callback)
457205 except KeyError:
206 # Add our own `cache_context` to argument list if the wrapped function
207 # has asked for one
208 if self.add_cache_context:
209 kwargs["cache_context"] = _CacheContext.get_instance(
210 cache, cache_key
211 )
212
458213 ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
459
460 def onErr(f):
461 cache.invalidate(cache_key)
462 return f
463
464 ret.addErrback(onErr)
465
466 result_d = cache.set(cache_key, ret, callback=invalidate_callback)
467 observer = result_d.observe()
468
469 return make_deferred_yieldable(observer)
214 ret = cache.set(cache_key, ret, callback=invalidate_callback)
215
216 return make_deferred_yieldable(ret)
470217
471218 wrapped = cast(_CachedFunction, _wrapped)
472219
525272
526273 def __get__(self, obj, objtype=None):
527274 cached_method = getattr(obj, self.cached_method_name)
528 cache = cached_method.cache
275 cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
529276 num_args = cached_method.num_args
530277
531278 @functools.wraps(self.orig)
565312 for arg in list_args:
566313 try:
567314 res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
568 if not isinstance(res, ObservableDeferred):
569 results[arg] = res
570 elif not res.has_succeeded():
571 res = res.observe()
315 if not res.called:
572316 res.addCallback(update_results_dict, arg)
573317 cached_defers.append(res)
574318 else:
575 results[arg] = res.get_result()
319 results[arg] = res.result
576320 except KeyError:
577321 missing.add(arg)
578322
639383
640384 _cache_context_objects = (
641385 WeakValueDictionary()
642 ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
643
644 def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
386 ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
387
388 def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
645389 self._cache = cache
646390 self._cache_key = cache_key
647391
650394 self._cache.invalidate(self._cache_key)
651395
652396 @classmethod
653 def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
397 def get_instance(
398 cls, cache, cache_key
399 ): # type: (DeferredCache, CacheKey) -> _CacheContext
654400 """Returns an instance constructed with the given arguments.
655401
656402 A new instance is only created if none already exists.
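As an editor's illustration of how these descriptors are consumed (not part of the diff), the hypothetical store below caches a per-user lookup with @cached and batches misses with @cachedList; ExampleStore, self.db and the method names are invented.

    from synapse.util.caches.descriptors import cached, cachedList


    class ExampleStore:
        def __init__(self, db):
            self.db = db  # hypothetical database accessor

        @cached(max_entries=1000)
        async def get_display_name(self, user_id: str) -> str:
            # only reached on a cache miss; the result is cached per user_id
            return await self.db.fetch_display_name(user_id)

        @cachedList(cached_method_name="get_display_name", list_name="user_ids")
        async def get_display_names(self, user_ids):
            # called with only the user_ids that missed the per-user cache; the
            # returned dict of user_id -> display name back-fills that cache
            return await self.db.fetch_display_names(user_ids)

Both wrappers hand back logcontext-following Deferreds, backed by the DeferredCache created in __get__ above.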
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
14 import enum
1515 import logging
1616 import threading
1717 from collections import namedtuple
18 from typing import Any
1819
1920 from synapse.util.caches.lrucache import LruCache
20
21 from . import register_cache
2221
2322 logger = logging.getLogger(__name__)
2423
3938 return len(self.value)
4039
4140
41 class _Sentinel(enum.Enum):
42 # defining a sentinel in this way allows mypy to correctly handle the
43 # type of a dictionary lookup.
44 sentinel = object()
45
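The enum member above is mainly a typing aid. A standalone sketch of the pattern (an editor's illustration, not part of the diff):

    import enum
    from typing import Dict, Union


    class _Sentinel(enum.Enum):
        # an enum member is a singleton with its own Literal type, which type
        # checkers can narrow away after an identity check; a bare object()
        # sentinel cannot be narrowed like this
        sentinel = object()


    def lookup(d: Dict[str, int], key: str) -> int:
        value = d.get(key, _Sentinel.sentinel)  # type: Union[int, _Sentinel]
        if value is _Sentinel.sentinel:
            return -1
        return value  # mypy now knows this is an int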
46
4247 class DictionaryCache:
4348 """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
4449 fetching a subset of dictionary keys for a particular key.
4550 """
4651
4752 def __init__(self, name, max_entries=1000):
48 self.cache = LruCache(max_size=max_entries, size_callback=len)
53 self.cache = LruCache(
54 max_size=max_entries, cache_name=name, size_callback=len
55 ) # type: LruCache[Any, DictionaryEntry]
4956
5057 self.name = name
5158 self.sequence = 0
5259 self.thread = None
53 # caches_by_name[name] = self.cache
54
55 class Sentinel:
56 __slots__ = []
57
58 self.sentinel = Sentinel()
59 self.metrics = register_cache("dictionary", name, self.cache)
6060
6161 def check_thread(self):
6262 expected_thread = self.thread
7979 Returns:
8080 DictionaryEntry
8181 """
82 entry = self.cache.get(key, self.sentinel)
83 if entry is not self.sentinel:
84 self.metrics.inc_hits()
85
82 entry = self.cache.get(key, _Sentinel.sentinel)
83 if entry is not _Sentinel.sentinel:
8684 if dict_keys is None:
8785 return DictionaryEntry(
8886 entry.full, entry.known_absent, dict(entry.value)
9492 {k: entry.value[k] for k in dict_keys if k in entry.value},
9593 )
9694
97 self.metrics.inc_misses()
9895 return DictionaryEntry(False, set(), {})
9996
10097 def invalidate(self, key):
1414
1515 import threading
1616 from functools import wraps
17 from typing import Callable, Optional, Type, Union
17 from typing import (
18 Any,
19 Callable,
20 Generic,
21 Iterable,
22 Optional,
23 Type,
24 TypeVar,
25 Union,
26 cast,
27 overload,
28 )
29
30 from typing_extensions import Literal
1831
1932 from synapse.config import cache as cache_config
33 from synapse.util.caches import CacheMetric, register_cache
2034 from synapse.util.caches.treecache import TreeCache
35
36 # Function type: the type used for invalidation callbacks
37 FT = TypeVar("FT", bound=Callable[..., Any])
38
39 # Key and Value type for the cache
40 KT = TypeVar("KT")
41 VT = TypeVar("VT")
42
43 # a general type var, distinct from either KT or VT
44 T = TypeVar("T")
2145
2246
2347 def enumerate_leaves(node, depth):
4064 self.callbacks = callbacks
4165
4266
43 class LruCache:
67 class LruCache(Generic[KT, VT]):
4468 """
45 Least-recently-used cache.
69 Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
70
4671 Supports del_multi only if cache_type=TreeCache
4772 If cache_type=TreeCache, all keys must be tuples.
48
49 Can also set callbacks on objects when getting/setting which are fired
50 when that key gets invalidated/evicted.
5173 """
5274
5375 def __init__(
5476 self,
5577 max_size: int,
78 cache_name: Optional[str] = None,
5679 keylen: int = 1,
5780 cache_type: Type[Union[dict, TreeCache]] = dict,
5881 size_callback: Optional[Callable] = None,
59 evicted_callback: Optional[Callable] = None,
82 metrics_collection_callback: Optional[Callable[[], None]] = None,
6083 apply_cache_factor_from_config: bool = True,
6184 ):
6285 """
6386 Args:
6487 max_size: The maximum amount of entries the cache can hold
6588
66 keylen: The length of the tuple used as the cache key
89 cache_name: The name of this cache, for the prometheus metrics. If unset,
90 no metrics will be reported on this cache.
91
92 keylen: The length of the tuple used as the cache key. Ignored unless
93 cache_type is `TreeCache`.
6794
6895 cache_type (type):
6996 type of underlying cache to be used. Typically one of dict
7198
7299 size_callback (func(V) -> int | None):
73100
74 evicted_callback (func(int)|None):
75 if not None, called on eviction with the size of the evicted
76 entry
101 metrics_collection_callback:
102 metrics collection callback. This is called early in the metrics
103 collection process, before any of the metrics registered with the
104 prometheus Registry are collected, so can be used to update any dynamic
105 metrics.
106
107 Ignored if cache_name is None.
77108
78109 apply_cache_factor_from_config (bool): If true, `max_size` will be
79110 multiplied by a cache factor derived from the homeserver config
92123 else:
93124 self.max_size = int(max_size)
94125
126 # register_cache might call our "set_cache_factor" callback; there's nothing to
127 # do yet when we get resized.
128 self._on_resize = None # type: Optional[Callable[[],None]]
129
130 if cache_name is not None:
131 metrics = register_cache(
132 "lru_cache",
133 cache_name,
134 self,
135 collect_callback=metrics_collection_callback,
136 ) # type: Optional[CacheMetric]
137 else:
138 metrics = None
139
140 # this is exposed for access from outside this class
141 self.metrics = metrics
142
95143 list_root = _Node(None, None, None, None)
96144 list_root.next_node = list_root
97145 list_root.prev_node = list_root
103151 todelete = list_root.prev_node
104152 evicted_len = delete_node(todelete)
105153 cache.pop(todelete.key, None)
106 if evicted_callback:
107 evicted_callback(evicted_len)
108
109 def synchronized(f):
154 if metrics:
155 metrics.inc_evictions(evicted_len)
156
157 def synchronized(f: FT) -> FT:
110158 @wraps(f)
111159 def inner(*args, **kwargs):
112160 with lock:
113161 return f(*args, **kwargs)
114162
115 return inner
163 return cast(FT, inner)
116164
117165 cached_cache_len = [0]
118166 if size_callback is not None:
166214 node.callbacks.clear()
167215 return deleted_len
168216
169 @synchronized
170 def cache_get(key, default=None, callbacks=[]):
217 @overload
218 def cache_get(
219 key: KT,
220 default: Literal[None] = None,
221 callbacks: Iterable[Callable[[], None]] = ...,
222 update_metrics: bool = ...,
223 ) -> Optional[VT]:
224 ...
225
226 @overload
227 def cache_get(
228 key: KT,
229 default: T,
230 callbacks: Iterable[Callable[[], None]] = ...,
231 update_metrics: bool = ...,
232 ) -> Union[T, VT]:
233 ...
234
235 @synchronized
236 def cache_get(
237 key: KT,
238 default: Optional[T] = None,
239 callbacks: Iterable[Callable[[], None]] = [],
240 update_metrics: bool = True,
241 ):
171242 node = cache.get(key, None)
172243 if node is not None:
173244 move_node_to_front(node)
174245 node.callbacks.update(callbacks)
246 if update_metrics and metrics:
247 metrics.inc_hits()
175248 return node.value
176249 else:
250 if update_metrics and metrics:
251 metrics.inc_misses()
177252 return default
178253
179254 @synchronized
180 def cache_set(key, value, callbacks=[]):
255 def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
181256 node = cache.get(key, None)
182257 if node is not None:
183258 # We sometimes store large objects, e.g. dicts, which cause
206281 evict()
207282
208283 @synchronized
209 def cache_set_default(key, value):
284 def cache_set_default(key: KT, value: VT) -> VT:
210285 node = cache.get(key, None)
211286 if node is not None:
212287 return node.value
215290 evict()
216291 return value
217292
218 @synchronized
219 def cache_pop(key, default=None):
293 @overload
294 def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
295 ...
296
297 @overload
298 def cache_pop(key: KT, default: T) -> Union[T, VT]:
299 ...
300
301 @synchronized
302 def cache_pop(key: KT, default: Optional[T] = None):
220303 node = cache.get(key, None)
221304 if node:
222305 delete_node(node)
226309 return default
227310
228311 @synchronized
229 def cache_del_multi(key):
312 def cache_del_multi(key: KT) -> None:
230313 """
231314 This will only work if constructed with cache_type=TreeCache
232315 """
233316 popped = cache.pop(key)
234317 if popped is None:
235318 return
236 for leaf in enumerate_leaves(popped, keylen - len(key)):
319 for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
237320 delete_node(leaf)
238321
239322 @synchronized
240 def cache_clear():
323 def cache_clear() -> None:
241324 list_root.next_node = list_root
242325 list_root.prev_node = list_root
243326 for node in cache.values():
248331 cached_cache_len[0] = 0
249332
250333 @synchronized
251 def cache_contains(key):
334 def cache_contains(key: KT) -> bool:
252335 return key in cache
253336
254337 self.sentinel = object()
338
339 # make sure that we clear out any excess entries after we get resized.
255340 self._on_resize = evict
341
256342 self.get = cache_get
257343 self.set = cache_set
258344 self.setdefault = cache_set_default
259345 self.pop = cache_pop
346 # `invalidate` is exposed for consistency with DeferredCache, so that it can be
347 # invalidated by the cache invalidation replication stream.
348 self.invalidate = cache_pop
260349 if cache_type is TreeCache:
261350 self.del_multi = cache_del_multi
262351 self.len = synchronized(cache_len)
300389 new_size = int(self._original_max_size * factor)
301390 if new_size != self.max_size:
302391 self.max_size = new_size
303 self._on_resize()
392 if self._on_resize:
393 self._on_resize()
304394 return True
305395 return False
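For reference, a brief editor's sketch (not part of the diff) of the reworked LruCache surface, assuming synapse 1.22 or later is importable; the cache name and values are invented.

    from synapse.util.caches.lrucache import LruCache

    # passing cache_name registers prometheus metrics for this cache, which
    # previously required wrapping it in the higher-level Cache class; omit it
    # for a plain, unmetered LRU
    cache = LruCache(max_size=10, cache_name="example")  # type: LruCache[str, int]

    cache.set("answer", 42, callbacks=[lambda: print("invalidated")])
    assert cache.get("answer") == 42
    assert cache.get("missing", default=-1) == -1

    # invalidate is now exposed as an alias for pop, so the callback registered
    # above fires here
    cache.invalidate("answer")  # prints "invalidated"
    assert cache.get("answer") is None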
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
1516
1617 from twisted.internet import defer
1718
1920 from synapse.util.async_helpers import ObservableDeferred
2021 from synapse.util.caches import register_cache
2122
23 if TYPE_CHECKING:
24 from synapse.app.homeserver import HomeServer
25
2226 logger = logging.getLogger(__name__)
2327
28 T = TypeVar("T")
2429
25 class ResponseCache:
30
31 class ResponseCache(Generic[T]):
2632 """
2733 This caches a deferred response. Until the deferred completes it will be
2834 returned from the cache. This means that if the client retries the request
3036 used rather than trying to compute a new response.
3137 """
3238
33 def __init__(self, hs, name, timeout_ms=0):
34 self.pending_result_cache = {} # Requests that haven't finished yet.
39 def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
40 # Requests that haven't finished yet.
41 self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
3542
3643 self.clock = hs.get_clock()
3744 self.timeout_sec = timeout_ms / 1000.0
3946 self._name = name
4047 self._metrics = register_cache("response_cache", name, self, resizable=False)
4148
42 def size(self):
49 def size(self) -> int:
4350 return len(self.pending_result_cache)
4451
45 def __len__(self):
52 def __len__(self) -> int:
4653 return self.size()
4754
48 def get(self, key):
55 def get(self, key: T) -> Optional[defer.Deferred]:
4956 """Look up the given key.
5057
5158 Can return either a new Deferred (which also doesn't follow the synapse
5764 from an absent cache entry.
5865
5966 Args:
60 key (hashable):
67 key: key to get/set in the cache
6168
6269 Returns:
63 twisted.internet.defer.Deferred|None|E: None if there is no entry
64 for this key; otherwise either a deferred result or the result
65 itself.
70 None if there is no entry for this key; otherwise a deferred which
71 resolves to the result.
6672 """
6773 result = self.pending_result_cache.get(key)
6874 if result is not None:
7278 self._metrics.inc_misses()
7379 return None
7480
75 def set(self, key, deferred):
81 def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
7682 """Set the entry for the given key to the given deferred.
7783
7884 *deferred* should run its callbacks in the sentinel logcontext (ie,
8490 result. You will probably want to make_deferred_yieldable the result.
8591
8692 Args:
87 key (hashable):
88 deferred (twisted.internet.defer.Deferred[T):
93 key: key to get/set in the cache
94 deferred: The deferred which resolves to the result.
8995
9096 Returns:
91 twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
92 result.
97 A new deferred which resolves to the actual result.
9398 """
9499 result = ObservableDeferred(deferred, consumeErrors=True)
95100 self.pending_result_cache[key] = result
106111 result.addBoth(remove)
107112 return result.observe()
108113
109 def wrap(self, key, callback, *args, **kwargs):
114 def wrap(
115 self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
116 ) -> defer.Deferred:
110117 """Wrap together a *get* and *set* call, taking care of logcontexts
111118
112119 First looks up the key in the cache, and if it is present makes it
117124
118125 Example usage:
119126
120 @defer.inlineCallbacks
121 def handle_request(request):
127 async def handle_request(request):
122128 # etc
123129 return result
124130
125 result = yield response_cache.wrap(
131 result = await response_cache.wrap(
126132 key,
127133 handle_request,
128134 request,
129135 )
130136
131137 Args:
132 key (hashable): key to get/set in the cache
138 key: key to get/set in the cache
133139
134 callback (callable): function to call if the key is not found in
140 callback: function to call if the key is not found in
135141 the cache
136142
137143 *args: positional parameters to pass to the callback, if it is used
139145 **kwargs: named parameters to pass to the callback, if it is used
140146
141147 Returns:
142 twisted.internet.defer.Deferred: yieldable result
148 Deferred which resolves to the result
143149 """
144150 result = self.get(key)
145151 if not result:
3333 self._data = {}
3434
3535 # the _CacheEntries, sorted by expiry time
36 self._expiry_list = SortedList()
36 self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
3737
3838 self._timer = timer
3939
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 import json
1614
1715 from frozendict import frozendict
1816
4846 pass
4947
5048 return o
51
52
53 def _handle_frozendict(obj):
54 """Helper for EventEncoder. Makes frozendicts serializable by returning
55 the underlying dict
56 """
57 if type(obj) is frozendict:
58 # fishing the protected dict out of the object is a bit nasty,
59 # but we don't really want the overhead of copying the dict.
60 return obj._dict
61 raise TypeError(
62 "Object of type %s is not JSON serializable" % obj.__class__.__name__
63 )
64
65
66 # A JSONEncoder which is capable of encoding frozendicts without barfing.
67 # Additionally reduce the whitespace produced by JSON encoding.
68 frozendict_json_encoder = json.JSONEncoder(
69 allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
70 )
3535 try:
3636 provider_config = provider_class.parse_config(provider.get("config"))
3737 except Exception as e:
38 raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
38 raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
3939
4040 return provider_class, provider_config
4141
1515 import logging
1616 import operator
1717
18 from synapse.api.constants import EventTypes, Membership
18 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
1919 from synapse.events.utils import prune_event
2020 from synapse.storage import Storage
2121 from synapse.storage.state import StateFilter
7676 )
7777
7878 ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
79 "m.ignored_user_list", user_id
79 AccountDataTypes.IGNORED_USER_LIST, user_id
8080 )
8181
82 # FIXME: This will explode if people upload something incorrect.
83 ignore_list = frozenset(
84 ignore_dict_content.get("ignored_users", {}).keys()
85 if ignore_dict_content
86 else []
87 )
82 ignore_list = frozenset()
83 if ignore_dict_content:
84 ignored_users_dict = ignore_dict_content.get("ignored_users", {})
85 if isinstance(ignored_users_dict, dict):
86 ignore_list = frozenset(ignored_users_dict.keys())
8887
8988 erased_senders = await storage.main.are_users_erased((e.sender for e in events))
9089
1414
1515 import sys
1616
17 from twisted.internet import epollreactor
17 try:
18 from twisted.internet.epollreactor import EPollReactor as Reactor
19 except ImportError:
20 from twisted.internet.pollreactor import PollReactor as Reactor
1821 from twisted.internet.main import installReactor
1922
2023 from synapse.config.homeserver import HomeServerConfig
4043 config_obj = HomeServerConfig()
4144 config_obj.parse_config_dict(config, "", "")
4245
43 hs = await setup_test_homeserver(
46 hs = setup_test_homeserver(
4447 cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
4548 )
4649 stor = hs.get_datastore()
6265 Instantiate and install a Twisted reactor suitable for testing (i.e. not the
6366 default global one).
6467 """
65 reactor = epollreactor.EPollReactor()
68 reactor = Reactor()
6669
6770 if "twisted.internet.reactor" in sys.modules:
6871 del sys.modules["twisted.internet.reactor"]
3333 # this requirement from the spec
3434 Inbound federation of state requires event_id as a mandatory paramater
3535
36 # Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
37 Can upload self-signing keys
38
3936 # Blacklisted until MSC2753 is implemented
4037 Local users can peek into world_readable rooms by room ID
4138 We can't peek into rooms with shared history_visibility
1818
1919 from twisted.internet import defer
2020
21 import synapse.handlers.auth
2221 from synapse.api.auth import Auth
2322 from synapse.api.constants import UserTypes
2423 from synapse.api.errors import (
3534 from tests.utils import mock_getRawHeaders, setup_test_homeserver
3635
3736
38 class TestHandlers:
39 def __init__(self, hs):
40 self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
41
42
4337 class AuthTestCase(unittest.TestCase):
4438 @defer.inlineCallbacks
4539 def setUp(self):
4640 self.state_handler = Mock()
4741 self.store = Mock()
4842
49 self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
43 self.hs = yield setup_test_homeserver(self.addCleanup)
5044 self.hs.get_datastore = Mock(return_value=self.store)
51 self.hs.handlers = TestHandlers(self.hs)
45 self.hs.get_auth_handler().store = self.store
5246 self.auth = Auth(self.hs)
5347
5448 # AuthBlocking reads from the hs' config on initialization. We need to
282276 self.store.get_device = Mock(return_value=defer.succeed(None))
283277
284278 token = yield defer.ensureDeferred(
285 self.hs.handlers.auth_handler.get_access_token_for_user_id(
279 self.hs.get_auth_handler().get_access_token_for_user_id(
286280 USER_ID, "DEVICE", valid_until_ms=None
287281 )
288282 )
4949 self.mock_http_client.put_json = DeferredMockCallable()
5050
5151 hs = yield setup_test_homeserver(
52 self.addCleanup,
53 handlers=None,
54 http_client=self.mock_http_client,
55 keyring=Mock(),
52 self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
5653 )
5754
5855 self.filtering = hs.get_filtering()
2121 def make_homeserver(self, reactor, clock):
2222
2323 hs = self.setup_test_homeserver(
24 http_client=None, homeserverToUse=GenericWorkerServer
24 http_client=None, homeserver_to_use=GenericWorkerServer
2525 )
2626
2727 return hs
2525 class FederationReaderOpenIDListenerTests(HomeserverTestCase):
2626 def make_homeserver(self, reactor, clock):
2727 hs = self.setup_test_homeserver(
28 http_client=None, homeserverToUse=GenericWorkerServer
28 http_client=None, homeserver_to_use=GenericWorkerServer
2929 )
3030 return hs
3131
8383 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
8484 def make_homeserver(self, reactor, clock):
8585 hs = self.setup_test_homeserver(
86 http_client=None, homeserverToUse=SynapseHomeServer
86 http_client=None, homeserver_to_use=SynapseHomeServer
8787 )
8888 return hs
8989
5959 self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
6060
6161 self.store.create_appservice_txn.assert_called_once_with(
62 service=service, events=events # txn made and saved
62 service=service, events=events, ephemeral=[] # txn made and saved
6363 )
6464 self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
6565 txn.complete.assert_called_once_with(self.store) # txn completed
8080 self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
8181
8282 self.store.create_appservice_txn.assert_called_once_with(
83 service=service, events=events # txn made and saved
83 service=service, events=events, ephemeral=[] # txn made and saved
8484 )
8585 self.assertEquals(0, txn.send.call_count) # txn not sent though
8686 self.assertEquals(0, txn.complete.call_count) # or completed
105105 self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
106106
107107 self.store.create_appservice_txn.assert_called_once_with(
108 service=service, events=events
108 service=service, events=events, ephemeral=[]
109109 )
110110 self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
111111 self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
201201 # Expect the event to be sent immediately.
202202 service = Mock(id=4)
203203 event = Mock()
204 self.queuer.enqueue(service, event)
205 self.txn_ctrl.send.assert_called_once_with(service, [event])
204 self.queuer.enqueue_event(service, event)
205 self.txn_ctrl.send.assert_called_once_with(service, [event], [])
206206
207207 def test_send_single_event_with_queue(self):
208208 d = defer.Deferred()
209 self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
209 self.txn_ctrl.send = Mock(
210 side_effect=lambda x, y, z: make_deferred_yieldable(d)
211 )
210212 service = Mock(id=4)
211213 event = Mock(event_id="first")
212214 event2 = Mock(event_id="second")
213215 event3 = Mock(event_id="third")
214216 # Send an event and don't resolve it just yet.
215 self.queuer.enqueue(service, event)
217 self.queuer.enqueue_event(service, event)
216218 # Send more events: expect send() to NOT be called multiple times.
217 self.queuer.enqueue(service, event2)
218 self.queuer.enqueue(service, event3)
219 self.txn_ctrl.send.assert_called_with(service, [event])
219 self.queuer.enqueue_event(service, event2)
220 self.queuer.enqueue_event(service, event3)
221 self.txn_ctrl.send.assert_called_with(service, [event], [])
220222 self.assertEquals(1, self.txn_ctrl.send.call_count)
221223 # Resolve the send event: expect the queued events to be sent
222224 d.callback(service)
223 self.txn_ctrl.send.assert_called_with(service, [event2, event3])
225 self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
224226 self.assertEquals(2, self.txn_ctrl.send.call_count)
225227
226228 def test_multiple_service_queues(self):
238240
239241 send_return_list = [srv_1_defer, srv_2_defer]
240242
241 def do_send(x, y):
243 def do_send(x, y, z):
242244 return make_deferred_yieldable(send_return_list.pop(0))
243245
244246 self.txn_ctrl.send = Mock(side_effect=do_send)
245247
246248 # send events for different ASes and make sure they are sent
247 self.queuer.enqueue(srv1, srv_1_event)
248 self.queuer.enqueue(srv1, srv_1_event2)
249 self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event])
250 self.queuer.enqueue(srv2, srv_2_event)
251 self.queuer.enqueue(srv2, srv_2_event2)
252 self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event])
249 self.queuer.enqueue_event(srv1, srv_1_event)
250 self.queuer.enqueue_event(srv1, srv_1_event2)
251 self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
252 self.queuer.enqueue_event(srv2, srv_2_event)
253 self.queuer.enqueue_event(srv2, srv_2_event2)
254 self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
253255
254256 # make sure callbacks for a service only send queued events for THAT
255257 # service
256258 srv_2_defer.callback(srv2)
257 self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2])
259 self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
258260 self.assertEquals(3, self.txn_ctrl.send.call_count)
261
262 def test_send_large_txns(self):
263 srv_1_defer = defer.Deferred()
264 srv_2_defer = defer.Deferred()
265 send_return_list = [srv_1_defer, srv_2_defer]
266
267 def do_send(x, y, z):
268 return make_deferred_yieldable(send_return_list.pop(0))
269
270 self.txn_ctrl.send = Mock(side_effect=do_send)
271
272 service = Mock(id=4, name="service")
273 event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
274 for event in event_list:
275 self.queuer.enqueue_event(service, event)
276
277 # Expect the first event to be sent immediately.
278 self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
279 srv_1_defer.callback(service)
280 # Then send the next 100 events
281 self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
282 srv_2_defer.callback(service)
283 # Then the final 99 events
284 self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
285 self.assertEquals(3, self.txn_ctrl.send.call_count)
286
287 def test_send_single_ephemeral_no_queue(self):
288 # Expect the event to be sent immediately.
289 service = Mock(id=4, name="service")
290 event_list = [Mock(name="event")]
291 self.queuer.enqueue_ephemeral(service, event_list)
292 self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
293
294 def test_send_multiple_ephemeral_no_queue(self):
295 # Expect the events to be sent immediately.
296 service = Mock(id=4, name="service")
297 event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
298 self.queuer.enqueue_ephemeral(service, event_list)
299 self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
300
301 def test_send_single_ephemeral_with_queue(self):
302 d = defer.Deferred()
303 self.txn_ctrl.send = Mock(
304 side_effect=lambda x, y, z: make_deferred_yieldable(d)
305 )
306 service = Mock(id=4)
307 event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
308 event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
309 event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
310
311 # Send an event and don't resolve it just yet.
312 self.queuer.enqueue_ephemeral(service, event_list_1)
313 # Send more events: expect send() to NOT be called multiple times.
314 self.queuer.enqueue_ephemeral(service, event_list_2)
315 self.queuer.enqueue_ephemeral(service, event_list_3)
316 self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
317 self.assertEquals(1, self.txn_ctrl.send.call_count)
318 # Resolve txn_ctrl.send
319 d.callback(service)
320 # Expect the queued events to be sent
321 self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
322 self.assertEquals(2, self.txn_ctrl.send.call_count)
323
324 def test_send_large_txns_ephemeral(self):
325 d = defer.Deferred()
326 self.txn_ctrl.send = Mock(
327 side_effect=lambda x, y, z: make_deferred_yieldable(d)
328 )
329 # Expect the first chunk of 100 events to be sent immediately.
330 service = Mock(id=4, name="service")
331 first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
332 second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
333 event_list = first_chunk + second_chunk
334 self.queuer.enqueue_ephemeral(service, event_list)
335 self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
336 d.callback(service)
337 self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
338 self.assertEquals(2, self.txn_ctrl.send.call_count)
314314 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
315315 def make_homeserver(self, reactor, clock):
316316 self.http_client = Mock()
317 hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
317 hs = self.setup_test_homeserver(http_client=self.http_client)
318318 return hs
319319
320320 def test_get_keys_from_server(self):
394394 }
395395 ]
396396
397 return self.setup_test_homeserver(
398 handlers=None, http_client=self.http_client, config=config
399 )
397 return self.setup_test_homeserver(http_client=self.http_client, config=config)
400398
401399 def build_perspectives_response(
402400 self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
3434 ]
3535
3636 def prepare(self, reactor, clock, hs):
37 self.admin_handler = hs.get_handlers().admin_handler
37 self.admin_handler = hs.get_admin_handler()
3838
3939 self.user1 = self.register_user("user1", "password")
4040 self.token1 = self.login("user1", "password")
1717 from twisted.internet import defer
1818
1919 from synapse.handlers.appservice import ApplicationServicesHandler
20 from synapse.types import RoomStreamToken
2021
2122 from tests.test_utils import make_awaitable
2223 from tests.utils import MockClock
6061 defer.succeed((0, [event])),
6162 defer.succeed((0, [])),
6263 ]
63 yield defer.ensureDeferred(self.handler.notify_interested_services(0))
64 yield defer.ensureDeferred(
65 self.handler.notify_interested_services(RoomStreamToken(None, 0))
66 )
6467 self.mock_scheduler.submit_event_for_as.assert_called_once_with(
6568 interested_service, event
6669 )
7982 defer.succeed((0, [event])),
8083 defer.succeed((0, [])),
8184 ]
82 yield defer.ensureDeferred(self.handler.notify_interested_services(0))
85 yield defer.ensureDeferred(
86 self.handler.notify_interested_services(RoomStreamToken(None, 0))
87 )
8388 self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
8489
8590 @defer.inlineCallbacks
96101 defer.succeed((0, [event])),
97102 defer.succeed((0, [])),
98103 ]
99 yield defer.ensureDeferred(self.handler.notify_interested_services(0))
104 yield defer.ensureDeferred(
105 self.handler.notify_interested_services(RoomStreamToken(None, 0))
106 )
100107 self.assertFalse(
101108 self.mock_as_api.query_user.called,
102109 "query_user called when it shouldn't have been.",
2020 import synapse
2121 import synapse.api.errors
2222 from synapse.api.errors import ResourceLimitError
23 from synapse.handlers.auth import AuthHandler
2423
2524 from tests import unittest
2625 from tests.test_utils import make_awaitable
2726 from tests.utils import setup_test_homeserver
2827
2928
30 class AuthHandlers:
31 def __init__(self, hs):
32 self.auth_handler = AuthHandler(hs)
33
34
3529 class AuthTestCase(unittest.TestCase):
3630 @defer.inlineCallbacks
3731 def setUp(self):
38 self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
39 self.hs.handlers = AuthHandlers(self.hs)
40 self.auth_handler = self.hs.handlers.auth_handler
32 self.hs = yield setup_test_homeserver(self.addCleanup)
33 self.auth_handler = self.hs.get_auth_handler()
4134 self.macaroon_generator = self.hs.get_macaroon_generator()
4235
4336 # MAU tests
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
22 # Copyright 2018 New Vector Ltd
3 # Copyright 2020 The Matrix.org Foundation C.I.C.
34 #
45 # Licensed under the Apache License, Version 2.0 (the "License");
56 # you may not use this file except in compliance with the License.
223224 )
224225 )
225226 self.reactor.advance(1000)
227
228
229 class DehydrationTestCase(unittest.HomeserverTestCase):
230 def make_homeserver(self, reactor, clock):
231 hs = self.setup_test_homeserver("server", http_client=None)
232 self.handler = hs.get_device_handler()
233 self.registration = hs.get_registration_handler()
234 self.auth = hs.get_auth()
235 self.store = hs.get_datastore()
236 return hs
237
238 def test_dehydrate_and_rehydrate_device(self):
239 user_id = "@boris:dehydration"
240
241 self.get_success(self.store.register_user(user_id, "foobar"))
242
243 # First check if we can store and fetch a dehydrated device
244 stored_dehydrated_device_id = self.get_success(
245 self.handler.store_dehydrated_device(
246 user_id=user_id,
247 device_data={"device_data": {"foo": "bar"}},
248 initial_device_display_name="dehydrated device",
249 )
250 )
251
252 retrieved_device_id, device_data = self.get_success(
253 self.handler.get_dehydrated_device(user_id=user_id)
254 )
255
256 self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
257 self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
258
259 # Create a new login for the user so that we can rehydrate the device later
260 device_id, access_token = self.get_success(
261 self.registration.register_device(
262 user_id=user_id, device_id=None, initial_display_name="new device",
263 )
264 )
265
266 # Trying to rehydrate a nonexistent device should throw an error
267 self.get_failure(
268 self.handler.rehydrate_device(
269 user_id=user_id,
270 access_token=access_token,
271 device_id="not the right device ID",
272 ),
273 synapse.api.errors.NotFoundError,
274 )
275
276 # rehydrating the right device should succeed and change our device ID
277 # to the dehydrated device's ID
278 res = self.get_success(
279 self.handler.rehydrate_device(
280 user_id=user_id,
281 access_token=access_token,
282 device_id=retrieved_device_id,
283 )
284 )
285
286 self.assertEqual(res, {"success": True})
287
288 # make sure that our device ID has changed
289 user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
290
291 self.assertEqual(user_info["device_id"], retrieved_device_id)
292
293 # make sure the device has the display name that was set from the login
294 res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
295
296 self.assertEqual(res["display_name"], "new device")
297
298 # make sure that the device ID that we were initially assigned no longer exists
299 self.get_failure(
300 self.handler.get_device(user_id, device_id),
301 synapse.api.errors.NotFoundError,
302 )
303
304 # make sure that there's no device available for dehydrating now
305 ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
306
307 self.assertIsNone(ret)
4747 federation_registry=self.mock_registry,
4848 )
4949
50 self.handler = hs.get_handlers().directory_handler
50 self.handler = hs.get_directory_handler()
5151
5252 self.store = hs.get_datastore()
5353
109109 ]
110110
111111 def prepare(self, reactor, clock, hs):
112 self.handler = hs.get_handlers().directory_handler
112 self.handler = hs.get_directory_handler()
113113
114114 # Create user
115115 self.admin_user = self.register_user("admin", "pass", admin=True)
172172
173173 def prepare(self, reactor, clock, hs):
174174 self.store = hs.get_datastore()
175 self.handler = hs.get_handlers().directory_handler
175 self.handler = hs.get_directory_handler()
176176 self.state_handler = hs.get_state_handler()
177177
178178 # Create user
288288
289289 def prepare(self, reactor, clock, hs):
290290 self.store = hs.get_datastore()
291 self.handler = hs.get_handlers().directory_handler
291 self.handler = hs.get_directory_handler()
292292 self.state_handler = hs.get_state_handler()
293293
294294 # Create user
441441 self.assertEquals(200, channel.code, channel.result)
442442
443443 self.room_list_handler = hs.get_room_list_handler()
444 self.directory_handler = hs.get_handlers().directory_handler
444 self.directory_handler = hs.get_directory_handler()
445445
446446 return hs
447447
3232 super().__init__(*args, **kwargs)
3333 self.hs = None # type: synapse.server.HomeServer
3434 self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
35 self.store = None # type: synapse.storage.Storage
3536
3637 @defer.inlineCallbacks
3738 def setUp(self):
3839 self.hs = yield utils.setup_test_homeserver(
39 self.addCleanup, handlers=None, federation_client=mock.Mock()
40 self.addCleanup, federation_client=mock.Mock()
4041 )
4142 self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
43 self.store = self.hs.get_datastore()
4244
4345 @defer.inlineCallbacks
4446 def test_query_local_devices_no_devices(self):
168170 "failures": {},
169171 "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
170172 },
173 )
174
175 @defer.inlineCallbacks
176 def test_fallback_key(self):
177 local_user = "@boris:" + self.hs.hostname
178 device_id = "xyz"
179 fallback_key = {"alg1:k1": "key1"}
180 otk = {"alg1:k2": "key2"}
181
182 # we shouldn't have any unused fallback keys yet
183 res = yield defer.ensureDeferred(
184 self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
185 )
186 self.assertEqual(res, [])
187
188 yield defer.ensureDeferred(
189 self.handler.upload_keys_for_user(
190 local_user,
191 device_id,
192 {"org.matrix.msc2732.fallback_keys": fallback_key},
193 )
194 )
195
196 # we should now have an unused alg1 key
197 res = yield defer.ensureDeferred(
198 self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
199 )
200 self.assertEqual(res, ["alg1"])
201
202 # claiming an OTK when no OTKs are available should return the fallback
203 # key
204 res = yield defer.ensureDeferred(
205 self.handler.claim_one_time_keys(
206 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
207 )
208 )
209 self.assertEqual(
210 res,
211 {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
212 )
213
214 # we shouldn't have any unused fallback keys again
215 res = yield defer.ensureDeferred(
216 self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
217 )
218 self.assertEqual(res, [])
219
220 # claiming an OTK again should return the same fallback key
221 res = yield defer.ensureDeferred(
222 self.handler.claim_one_time_keys(
223 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
224 )
225 )
226 self.assertEqual(
227 res,
228 {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
229 )
230
231 # if the user uploads a one-time key, the next claim should fetch the
232 # one-time key, and then go back to the fallback
233 yield defer.ensureDeferred(
234 self.handler.upload_keys_for_user(
235 local_user, device_id, {"one_time_keys": otk}
236 )
237 )
238
239 res = yield defer.ensureDeferred(
240 self.handler.claim_one_time_keys(
241 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
242 )
243 )
244 self.assertEqual(
245 res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
246 )
247
248 res = yield defer.ensureDeferred(
249 self.handler.claim_one_time_keys(
250 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
251 )
252 )
253 self.assertEqual(
254 res,
255 {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
171256 )
172257
173258 @defer.inlineCallbacks
5353 @defer.inlineCallbacks
5454 def setUp(self):
5555 self.hs = yield utils.setup_test_homeserver(
56 self.addCleanup, handlers=None, replication_layer=mock.Mock()
56 self.addCleanup, replication_layer=mock.Mock()
5757 )
5858 self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
5959 self.local_user = "@boris:" + self.hs.hostname
3737
3838 def make_homeserver(self, reactor, clock):
3939 hs = self.setup_test_homeserver(http_client=None)
40 self.handler = hs.get_handlers().federation_handler
40 self.handler = hs.get_federation_handler()
4141 self.store = hs.get_datastore()
4242 return hs
4343
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import logging
15 from typing import Tuple
16
17 from synapse.api.constants import EventTypes
18 from synapse.events import EventBase
19 from synapse.events.snapshot import EventContext
20 from synapse.rest import admin
21 from synapse.rest.client.v1 import login, room
22 from synapse.types import create_requester
23 from synapse.util.stringutils import random_string
24
25 from tests import unittest
26
27 logger = logging.getLogger(__name__)
28
29
30 class EventCreationTestCase(unittest.HomeserverTestCase):
31 servlets = [
32 admin.register_servlets,
33 login.register_servlets,
34 room.register_servlets,
35 ]
36
37 def prepare(self, reactor, clock, hs):
38 self.handler = self.hs.get_event_creation_handler()
39 self.persist_event_storage = self.hs.get_storage().persistence
40
41 self.user_id = self.register_user("tester", "foobar")
42 self.access_token = self.login("tester", "foobar")
43 self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
44
45 self.info = self.get_success(
46 self.hs.get_datastore().get_user_by_access_token(self.access_token,)
47 )
48 self.token_id = self.info["token_id"]
49
50 self.requester = create_requester(self.user_id, access_token_id=self.token_id)
51
52 def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
53 """Create a new event with the given transaction ID. All events produced
54 by this method will be considered duplicates.
55 """
56
57 # We create a new event with a random body, as otherwise we'll produce
58 # *exactly* the same event with the same hash, and so same event ID.
59 return self.get_success(
60 self.handler.create_event(
61 self.requester,
62 {
63 "type": EventTypes.Message,
64 "room_id": self.room_id,
65 "sender": self.requester.user.to_string(),
66 "content": {"msgtype": "m.text", "body": random_string(5)},
67 },
68 txn_id=txn_id,
69 )
70 )
71
72 def test_duplicated_txn_id(self):
73 """Test that attempting to handle/persist an event with a transaction ID
74 that has already been persisted correctly returns the old event and does
75 *not* produce duplicate messages.
76 """
77
78 txn_id = "something_suitably_random"
79
80 event1, context = self._create_duplicate_event(txn_id)
81
82 ret_event1 = self.get_success(
83 self.handler.handle_new_client_event(self.requester, event1, context)
84 )
85 stream_id1 = ret_event1.internal_metadata.stream_ordering
86
87 self.assertEqual(event1.event_id, ret_event1.event_id)
88
89 event2, context = self._create_duplicate_event(txn_id)
90
91 # We want to test that the deduplication at the persist event end works,
92 # so we want to make sure we test with different events.
93 self.assertNotEqual(event1.event_id, event2.event_id)
94
95 ret_event2 = self.get_success(
96 self.handler.handle_new_client_event(self.requester, event2, context)
97 )
98 stream_id2 = ret_event2.internal_metadata.stream_ordering
99
100 # Assert that the returned values match those from the initial event
101 # rather than the new one.
102 self.assertEqual(ret_event1.event_id, ret_event2.event_id)
103 self.assertEqual(stream_id1, stream_id2)
104
105 # Let's test that calling `persist_event` directly also does the right
106 # thing.
107 event3, context = self._create_duplicate_event(txn_id)
108 self.assertNotEqual(event1.event_id, event3.event_id)
109
110 ret_event3, event_pos3, _ = self.get_success(
111 self.persist_event_storage.persist_event(event3, context)
112 )
113
114 # Assert that the returned values match those from the initial event
115 # rather than the new one.
116 self.assertEqual(ret_event1.event_id, ret_event3.event_id)
117 self.assertEqual(stream_id1, event_pos3.stream)
118
119 # Let's test that calling `persist_events` directly also does the right
120 # thing.
121 event4, context = self._create_duplicate_event(txn_id)
122 self.assertNotEqual(event1.event_id, event4.event_id)
123
124 events, _ = self.get_success(
125 self.persist_event_storage.persist_events([(event4, context)])
126 )
127 ret_event4 = events[0]
128
129 # Assert that the returned values match those from the initial event
130 # rather than the new one.
131 self.assertEqual(ret_event1.event_id, ret_event4.event_id)
132
133 def test_duplicated_txn_id_one_call(self):
134 """Test that we correctly handle duplicates that we try and persist at
135 the same time.
136 """
137
138 txn_id = "something_else_suitably_random"
139
140 # Create two duplicate events to persist at the same time
141 event1, context1 = self._create_duplicate_event(txn_id)
142 event2, context2 = self._create_duplicate_event(txn_id)
143
144 # Ensure their event IDs are different to start with
145 self.assertNotEqual(event1.event_id, event2.event_id)
146
147 events, _ = self.get_success(
148 self.persist_event_storage.persist_events(
149 [(event1, context1), (event2, context2)]
150 )
151 )
152
153 # Check that we've deduplicated the events.
154 self.assertEqual(len(events), 2)
155 self.assertEqual(events[0].event_id, events[1].event_id)
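
The two tests above cover deduplication at the handler and persistence layers; the transaction ID itself originates from the client-server API, where retrying a `PUT .../send/{eventType}/{txnId}` with the same `txnId` is expected to return the original event ID. A minimal sketch of that client-side view, assuming a local homeserver and placeholder credentials:

import requests

BASE = "http://localhost:8008/_matrix/client/r0"   # assumed local homeserver
TOKEN = "example_access_token"                      # placeholder
ROOM_ID = "!someroom:example.org"                   # placeholder
TXN_ID = "something_suitably_random"

url = "{}/rooms/{}/send/m.room.message/{}".format(BASE, ROOM_ID, TXN_ID)
body = {"msgtype": "m.text", "body": "hello"}
headers = {"Authorization": "Bearer " + TOKEN}

first = requests.put(url, json=body, headers=headers).json()
retry = requests.put(url, json=body, headers=headers).json()  # same txn_id
assert first["event_id"] == retry["event_id"]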
285285 h._validate_metadata,
286286 )
287287
288 # Tests for configs that the userinfo endpoint
288 # Tests for configs that require the userinfo endpoint
289289 self.assertFalse(h._uses_userinfo)
290 h._scopes = [] # do not request the openid scope
290 self.assertEqual(h._user_profile_method, "auto")
291 h._user_profile_method = "userinfo_endpoint"
292 self.assertTrue(h._uses_userinfo)
293
294 # Revert the profile method and do not request the "openid" scope.
295 h._user_profile_method = "auto"
296 h._scopes = []
291297 self.assertTrue(h._uses_userinfo)
292298 self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
293299
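
The assertions above pin down when the provider's userinfo endpoint is needed: always when `user_profile_method` is `"userinfo_endpoint"`, and otherwise only when the `"openid"` scope is not requested. Written out as a plain function (illustrative only, not the handler's actual code; the names mirror the test attributes):

from typing import List

def uses_userinfo(user_profile_method: str, scopes: List[str]) -> bool:
    if user_profile_method == "userinfo_endpoint":
        return True
    # With "auto", the userinfo endpoint is only needed when the "openid"
    # scope is not requested (i.e. no ID token with profile claims).
    return "openid" not in scopes

assert uses_userinfo("userinfo_endpoint", ["openid"]) is True
assert uses_userinfo("auto", ["openid"]) is False
assert uses_userinfo("auto", []) is True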
469469 def prepare(self, reactor, clock, hs):
470470 self.federation_sender = hs.get_federation_sender()
471471 self.event_builder_factory = hs.get_event_builder_factory()
472 self.federation_handler = hs.get_handlers().federation_handler
472 self.federation_handler = hs.get_federation_handler()
473473 self.presence_handler = hs.get_presence_handler()
474474
475475 # self.event_builder_for_2 = EventBuilderFactory(hs)
614614 self.store.get_latest_event_ids_in_room(room_id)
615615 )
616616
617 event = self.get_success(builder.build(prev_event_ids))
617 event = self.get_success(builder.build(prev_event_ids, None))
618618
619619 self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
620620
1919
2020 import synapse.types
2121 from synapse.api.errors import AuthError, SynapseError
22 from synapse.handlers.profile import MasterProfileHandler
2322 from synapse.types import UserID
2423
2524 from tests import unittest
2726 from tests.utils import setup_test_homeserver
2827
2928
30 class ProfileHandlers:
31 def __init__(self, hs):
32 self.profile_handler = MasterProfileHandler(hs)
33
34
3529 class ProfileTestCase(unittest.TestCase):
3630 """ Tests profile management. """
3731
5044 hs = yield setup_test_homeserver(
5145 self.addCleanup,
5246 http_client=None,
53 handlers=None,
5447 resource_for_federation=Mock(),
5548 federation_client=self.mock_federation,
5649 federation_server=Mock(),
1717 from synapse.api.auth import Auth
1818 from synapse.api.constants import UserTypes
1919 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
20 from synapse.handlers.register import RegistrationHandler
2120 from synapse.spam_checker_api import RegistrationBehaviour
2221 from synapse.types import RoomAlias, UserID, create_requester
2322
2625 from tests.utils import mock_getRawHeaders
2726
2827 from .. import unittest
29
30
31 class RegistrationHandlers:
32 def __init__(self, hs):
33 self.registration_handler = RegistrationHandler(hs)
3428
3529
3630 class RegistrationTestCase(unittest.HomeserverTestCase):
153147 room_alias_str = "#room:test"
154148 user_id = self.get_success(self.handler.register_user(localpart="jeff"))
155149 rooms = self.get_success(self.store.get_rooms_for_user(user_id))
156 directory_handler = self.hs.get_handlers().directory_handler
150 directory_handler = self.hs.get_directory_handler()
157151 room_alias = RoomAlias.from_string(room_alias_str)
158152 room_id = self.get_success(directory_handler.get_association(room_alias))
159153
192186 user_id = self.get_success(self.handler.register_user(localpart="support"))
193187 rooms = self.get_success(self.store.get_rooms_for_user(user_id))
194188 self.assertEqual(len(rooms), 0)
195 directory_handler = self.hs.get_handlers().directory_handler
189 directory_handler = self.hs.get_directory_handler()
196190 room_alias = RoomAlias.from_string(room_alias_str)
197191 self.get_failure(directory_handler.get_association(room_alias), SynapseError)
198192
204198 self.store.is_real_user = Mock(return_value=make_awaitable(True))
205199 user_id = self.get_success(self.handler.register_user(localpart="real"))
206200 rooms = self.get_success(self.store.get_rooms_for_user(user_id))
207 directory_handler = self.hs.get_handlers().directory_handler
201 directory_handler = self.hs.get_directory_handler()
208202 room_alias = RoomAlias.from_string(room_alias_str)
209203 room_id = self.get_success(directory_handler.get_association(room_alias))
210204
236230 user_id = self.get_success(self.handler.register_user(localpart="jeff"))
237231
238232 # Ensure the room was created.
239 directory_handler = self.hs.get_handlers().directory_handler
233 directory_handler = self.hs.get_directory_handler()
240234 room_alias = RoomAlias.from_string(room_alias_str)
241235 room_id = self.get_success(directory_handler.get_association(room_alias))
242236
265259 user_id = self.get_success(self.handler.register_user(localpart="jeff"))
266260
267261 # Ensure the room was created.
268 directory_handler = self.hs.get_handlers().directory_handler
262 directory_handler = self.hs.get_directory_handler()
269263 room_alias = RoomAlias.from_string(room_alias_str)
270264 room_id = self.get_success(directory_handler.get_association(room_alias))
271265
303297 user_id = self.get_success(self.handler.register_user(localpart="jeff"))
304298
305299 # Ensure the room was created.
306 directory_handler = self.hs.get_handlers().directory_handler
300 directory_handler = self.hs.get_directory_handler()
307301 room_alias = RoomAlias.from_string(room_alias_str)
308302 room_id = self.get_success(directory_handler.get_association(room_alias))
309303
346340 )
347341
348342 # Ensure the room was created.
349 directory_handler = self.hs.get_handlers().directory_handler
343 directory_handler = self.hs.get_directory_handler()
350344 room_alias = RoomAlias.from_string(room_alias_str)
351345 room_id = self.get_success(directory_handler.get_association(room_alias))
352346
383377 user_id = self.get_success(self.handler.register_user(localpart="jeff"))
384378
385379 # Ensure the room was created.
386 directory_handler = self.hs.get_handlers().directory_handler
380 directory_handler = self.hs.get_directory_handler()
387381 room_alias = RoomAlias.from_string(room_alias_str)
388382 room_id = self.get_success(directory_handler.get_association(room_alias))
389383
412406 )
413407 )
414408 self.get_success(
415 event_creation_handler.send_nonmember_event(requester, event, context)
409 event_creation_handler.handle_new_client_event(requester, event, context)
416410 )
417411
418412 # Register a second user, which won't be in the room (or even have an invite)
6464 mock_federation_client = Mock(spec=["put_json"])
6565 mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
6666
67 datastores = Mock()
68 datastores.main = Mock(
69 spec=[
70 # Bits that Federation needs
71 "prep_send_transaction",
72 "delivered_txn",
73 "get_received_txn_response",
74 "set_received_txn_response",
75 "get_destination_last_successful_stream_ordering",
76 "get_destination_retry_timings",
77 "get_devices_by_remote",
78 "maybe_store_room_on_invite",
79 # Bits that user_directory needs
80 "get_user_directory_stream_pos",
81 "get_current_state_deltas",
82 "get_device_updates_by_remote",
83 "get_room_max_stream_ordering",
84 ]
85 )
86
8767 # the tests assume that we are starting at unix time 1000
8868 reactor.pump((1000,))
8969
9373 keyring=mock_keyring,
9474 replication_streams={},
9575 )
96
97 hs.datastores = datastores
9876
9977 return hs
10078
11391 "retry_interval": 0,
11492 "failure_ts": None,
11593 }
116 self.datastore.get_destination_retry_timings.return_value = defer.succeed(
117 retry_timings_res
118 )
119
120 self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
121 (0, [])
122 )
123
124 self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
125 None
94 self.datastore.get_destination_retry_timings = Mock(
95 return_value=defer.succeed(retry_timings_res)
96 )
97
98 self.datastore.get_device_updates_by_remote = Mock(
99 return_value=make_awaitable((0, []))
100 )
101
102 self.datastore.get_destination_last_successful_stream_ordering = Mock(
103 return_value=make_awaitable(None)
126104 )
127105
128106 def get_received_txn_response(*args):
144122
145123 self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
146124
147 def get_users_in_room(room_id):
148 return defer.succeed({str(u) for u in self.room_members})
125 async def get_users_in_room(room_id):
126 return {str(u) for u in self.room_members}
149127
150128 self.datastore.get_users_in_room = get_users_in_room
151129
152 self.datastore.get_user_directory_stream_pos.side_effect = (
153 # we deliberately return a non-None stream pos to avoid doing an initial_spam
154 lambda: make_awaitable(1)
155 )
156
157 self.datastore.get_current_state_deltas.return_value = (0, None)
130 self.datastore.get_user_directory_stream_pos = Mock(
131 side_effect=(
132 # we deliberately return a non-None stream pos to avoid doing an initial_spam
133 lambda: make_awaitable(1)
134 )
135 )
136
137 self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
158138
159139 self.datastore.get_to_device_stream_token = lambda: 0
160140 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
7777 "server_name",
7878 "name",
7979 ]
80 self.assertEqual(set(log.keys()), set(expected_log_keys))
80 self.assertCountEqual(log.keys(), expected_log_keys)
8181
8282 # It contains the data we expect.
8383 self.assertEqual(log["name"], "wally")
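
The switch from comparing sets to `assertCountEqual` tightens the check: converting to `set()` silently collapses duplicates, whereas `assertCountEqual` requires the same elements with the same multiplicities. A standalone illustration:

import unittest

class _Demo(unittest.TestCase):
    def test_sets_hide_duplicates(self):
        # Comparing sets passes even though one side has a duplicate...
        self.assertEqual(set(["a", "a", "b"]), set(["a", "b"]))
        # ...while assertCountEqual would flag it (uncommenting this fails):
        # self.assertCountEqual(["a", "a", "b"], ["a", "b"])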
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 from synapse.module_api import ModuleApi
14 from mock import Mock
15
16 from synapse.events import EventBase
17 from synapse.rest import admin
18 from synapse.rest.client.v1 import login, room
19 from synapse.types import create_requester
1620
1721 from tests.unittest import HomeserverTestCase
1822
1923
2024 class ModuleApiTestCase(HomeserverTestCase):
25 servlets = [
26 admin.register_servlets,
27 login.register_servlets,
28 room.register_servlets,
29 ]
30
2131 def prepare(self, reactor, clock, homeserver):
2232 self.store = homeserver.get_datastore()
23 self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
33 self.module_api = homeserver.get_module_api()
34 self.event_creation_handler = homeserver.get_event_creation_handler()
2435
2536 def test_can_register_user(self):
2637 """Tests that an external module can register a user"""
5162 # Check that the displayname was assigned
5263 displayname = self.get_success(self.store.get_profile_displayname("bob"))
5364 self.assertEqual(displayname, "Bobberino")
65
66 def test_sending_events_into_room(self):
67 """Tests that a module can send events into a room"""
68 # Mock out create_and_send_nonmember_event to check whether events are being sent
69 self.event_creation_handler.create_and_send_nonmember_event = Mock(
70 spec=[],
71 side_effect=self.event_creation_handler.create_and_send_nonmember_event,
72 )
73
74 # Create a user and room to play with
75 user_id = self.register_user("summer", "monkey")
76 tok = self.login("summer", "monkey")
77 room_id = self.helper.create_room_as(user_id, tok=tok)
78
79 # Create and send a non-state event
80 content = {"body": "I am a puppet", "msgtype": "m.text"}
81 event_dict = {
82 "room_id": room_id,
83 "type": "m.room.message",
84 "content": content,
85 "sender": user_id,
86 }
87 event = self.get_success(
88 self.module_api.create_and_send_event_into_room(event_dict)
89 ) # type: EventBase
90 self.assertEqual(event.sender, user_id)
91 self.assertEqual(event.type, "m.room.message")
92 self.assertEqual(event.room_id, room_id)
93 self.assertFalse(hasattr(event, "state_key"))
94 self.assertDictEqual(event.content, content)
95
96 # Check that the event was sent
97 self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
98 create_requester(user_id),
99 event_dict,
100 ratelimit=False,
101 ignore_shadow_ban=True,
102 )
103
104 # Create and send a state event
105 content = {
106 "events_default": 0,
107 "users": {user_id: 100},
108 "state_default": 50,
109 "users_default": 0,
110 "events": {"test.event.type": 25},
111 }
112 event_dict = {
113 "room_id": room_id,
114 "type": "m.room.power_levels",
115 "content": content,
116 "sender": user_id,
117 "state_key": "",
118 }
119 event = self.get_success(
120 self.module_api.create_and_send_event_into_room(event_dict)
121 ) # type: EventBase
122 self.assertEqual(event.sender, user_id)
123 self.assertEqual(event.type, "m.room.power_levels")
124 self.assertEqual(event.room_id, room_id)
125 self.assertEqual(event.state_key, "")
126 self.assertDictEqual(event.content, content)
127
128 # Check that the event was sent
129 self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
130 create_requester(user_id),
131 {
132 "type": "m.room.power_levels",
133 "content": content,
134 "room_id": room_id,
135 "sender": user_id,
136 "state_key": "",
137 },
138 ratelimit=False,
139 ignore_shadow_ban=True,
140 )
141
142 # Check that we can't send membership events
143 content = {
144 "membership": "leave",
145 }
146 event_dict = {
147 "room_id": room_id,
148 "type": "m.room.member",
149 "content": content,
150 "sender": user_id,
151 "state_key": user_id,
152 }
153 self.get_failure(
154 self.module_api.create_and_send_event_into_room(event_dict), Exception
155 )
156
157 def test_public_rooms(self):
158 """Tests that a room can be added and removed from the public rooms list,
159 as well as have its public rooms directory state queried.
160 """
161 # Create a user and room to play with
162 user_id = self.register_user("kermit", "monkey")
163 tok = self.login("kermit", "monkey")
164 room_id = self.helper.create_room_as(user_id, tok=tok)
165
166 # The room should not currently be in the public rooms directory
167 is_in_public_rooms = self.get_success(
168 self.module_api.public_room_list_manager.room_is_in_public_room_list(
169 room_id
170 )
171 )
172 self.assertFalse(is_in_public_rooms)
173
174 # Let's try adding it to the public rooms directory
175 self.get_success(
176 self.module_api.public_room_list_manager.add_room_to_public_room_list(
177 room_id
178 )
179 )
180
181 # And checking whether it's in there...
182 is_in_public_rooms = self.get_success(
183 self.module_api.public_room_list_manager.room_is_in_public_room_list(
184 room_id
185 )
186 )
187 self.assertTrue(is_in_public_rooms)
188
189 # Let's remove it again
190 self.get_success(
191 self.module_api.public_room_list_manager.remove_room_from_public_room_list(
192 room_id
193 )
194 )
195
196 # Should be gone
197 is_in_public_rooms = self.get_success(
198 self.module_api.public_room_list_manager.room_is_in_public_room_list(
199 room_id
200 )
201 )
202 self.assertFalse(is_in_public_rooms)
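
The calls exercised above are the same ones a loadable module would make through the `ModuleApi` object it is handed at start-up. A rough sketch of that caller's side (the class name and behaviour are hypothetical; only the `ModuleApi` calls mirror the test):

class ExampleModule:
    def __init__(self, config, module_api):
        self._api = module_api

    async def announce(self, room_id: str, sender: str):
        # Send a plain notice into the room on behalf of `sender`.
        await self._api.create_and_send_event_into_room(
            {
                "room_id": room_id,
                "type": "m.room.message",
                "sender": sender,
                "content": {"msgtype": "m.notice", "body": "announcement"},
            }
        )
        # Publish the room in the local public rooms directory.
        await self._api.public_room_list_manager.add_room_to_public_room_list(
            room_id
        )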
157157 # We should get emailed about those messages
158158 self._check_for_mail()
159159
160 def test_encrypted_message(self):
161 room = self.helper.create_room_as(self.user_id, tok=self.access_token)
162 self.helper.invite(
163 room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
164 )
165 self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
166
167 # The other user sends some messages
168 self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
169
170 # We should get emailed about that message
171 self._check_for_mail()
172
160173 def _check_for_mail(self):
161 "Check that the user receives an email notification"
174 """Check that the user receives an email notification"""
162175
163176 # Get the stream ordering before it gets sent
164177 pushers = self.get_success(
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1514 import logging
1615 from typing import Any, Callable, List, Optional, Tuple
1716
1817 import attr
18 import hiredis
1919
2020 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
21 from twisted.internet.protocol import Protocol
2122 from twisted.internet.task import LoopingCall
2223 from twisted.web.http import HTTPChannel
2324
2627 GenericWorkerServer,
2728 )
2829 from synapse.http.server import JsonResource
29 from synapse.http.site import SynapseRequest
30 from synapse.http.site import SynapseRequest, SynapseSite
3031 from synapse.replication.http import ReplicationRestResource, streams
3132 from synapse.replication.tcp.handler import ReplicationCommandHandler
3233 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
5758 self.reactor.lookups["testserv"] = "1.2.3.4"
5859 self.worker_hs = self.setup_test_homeserver(
5960 http_client=None,
60 homeserverToUse=GenericWorkerServer,
61 homeserver_to_use=GenericWorkerServer,
6162 config=self._get_worker_hs_config(),
6263 reactor=self.reactor,
6364 )
196197 self.server_factory = ReplicationStreamProtocolFactory(self.hs)
197198 self.streamer = self.hs.get_replication_streamer()
198199
200 # Fake in memory Redis server that servers can connect to.
201 self._redis_server = FakeRedisPubSubServer()
202
199203 store = self.hs.get_datastore()
200204 self.database_pool = store.db_pool
201205
202206 self.reactor.lookups["testserv"] = "1.2.3.4"
203
204 self._worker_hs_to_resource = {}
207 self.reactor.lookups["localhost"] = "127.0.0.1"
208
209 # A map from a HS instance to the associated HTTP Site to use for
210 # handling inbound HTTP requests to that instance.
211 self._hs_to_site = {self.hs: self.site}
212
213 if self.hs.config.redis.redis_enabled:
214 # Handle attempts to connect to fake redis server.
215 self.reactor.add_tcp_client_callback(
216 "localhost", 6379, self.connect_any_redis_attempts,
217 )
218
219 self.hs.get_tcp_replication().start_replication(self.hs)
205220
206221 # When we see a connection attempt to the master replication listener we
207222 # automatically set up the connection. This is so that tests don't
208223 # manually have to go and explicitly set it up each time (plus sometimes
209224 # it is impossible to write the handling explicitly in the tests).
225 #
226 # Register the master replication listener:
210227 self.reactor.add_tcp_client_callback(
211 "1.2.3.4", 8765, self._handle_http_replication_attempt
228 "1.2.3.4",
229 8765,
230 lambda: self._handle_http_replication_attempt(self.hs, 8765),
212231 )
213232
214233 def create_test_json_resource(self):
246265 config.update(extra_config)
247266
248267 worker_hs = self.setup_test_homeserver(
249 homeserverToUse=GenericWorkerServer,
268 homeserver_to_use=GenericWorkerServer,
250269 config=config,
251270 reactor=self.reactor,
252271 **kwargs
253272 )
254273
274 # If the instance is in the `instance_map` config then workers may try
275 # and send HTTP requests to it, so we register it with
276 # `_handle_http_replication_attempt` like we do with the master HS.
277 instance_name = worker_hs.get_instance_name()
278 instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
279 if instance_loc:
280 # Ensure the host is one that has a fake DNS entry.
281 if instance_loc.host not in self.reactor.lookups:
282 raise Exception(
283 "Host does not have an IP for instance_map[%r].host = %r"
284 % (instance_name, instance_loc.host,)
285 )
286
287 self.reactor.add_tcp_client_callback(
288 self.reactor.lookups[instance_loc.host],
289 instance_loc.port,
290 lambda: self._handle_http_replication_attempt(
291 worker_hs, instance_loc.port
292 ),
293 )
294
255295 store = worker_hs.get_datastore()
256296 store.db_pool._db_pool = self.database_pool._db_pool
257297
258 repl_handler = ReplicationCommandHandler(worker_hs)
259 client = ClientReplicationStreamProtocol(
260 worker_hs, "client", "test", self.clock, repl_handler,
261 )
262 server = self.server_factory.buildProtocol(None)
263
264 client_transport = FakeTransport(server, self.reactor)
265 client.makeConnection(client_transport)
266
267 server_transport = FakeTransport(client, self.reactor)
268 server.makeConnection(server_transport)
298 # Set up TCP replication between master and the new worker if we don't
299 # have Redis support enabled.
300 if not worker_hs.config.redis_enabled:
301 repl_handler = ReplicationCommandHandler(worker_hs)
302 client = ClientReplicationStreamProtocol(
303 worker_hs, "client", "test", self.clock, repl_handler,
304 )
305 server = self.server_factory.buildProtocol(None)
306
307 client_transport = FakeTransport(server, self.reactor)
308 client.makeConnection(client_transport)
309
310 server_transport = FakeTransport(client, self.reactor)
311 server.makeConnection(server_transport)
269312
270313 # Set up a resource for the worker
271 resource = ReplicationRestResource(self.hs)
314 resource = ReplicationRestResource(worker_hs)
272315
273316 for servlet in self.servlets:
274317 servlet(worker_hs, resource)
275318
276 self._worker_hs_to_resource[worker_hs] = resource
319 self._hs_to_site[worker_hs] = SynapseSite(
320 logger_name="synapse.access.http.fake",
321 site_tag="{}-{}".format(
322 worker_hs.config.server.server_name, worker_hs.get_instance_name()
323 ),
324 config=worker_hs.config.server.listeners[0],
325 resource=resource,
326 server_version_string="1",
327 )
328
329 if worker_hs.config.redis.redis_enabled:
330 worker_hs.get_tcp_replication().start_replication(worker_hs)
277331
278332 return worker_hs
279333
284338 return config
285339
286340 def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
287 render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
341 render(request, self._hs_to_site[worker_hs].resource, self.reactor)
288342
289343 def replicate(self):
290344 """Tell the master side of replication that something has happened, and then
293347 self.streamer.on_notifier_poke()
294348 self.pump()
295349
296 def _handle_http_replication_attempt(self):
297 """Handles a connection attempt to the master replication HTTP
298 listener.
350 def _handle_http_replication_attempt(self, hs, repl_port):
351 """Handles a connection attempt to the given HS replication HTTP
352 listener on the given port.
299353 """
300354
301355 # We should have at least one outbound connection attempt, where the
304358 self.assertGreaterEqual(len(clients), 1)
305359 (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
306360 self.assertEqual(host, "1.2.3.4")
307 self.assertEqual(port, 8765)
361 self.assertEqual(port, repl_port)
308362
309363 # Set up client side protocol
310364 client_protocol = client_factory.buildProtocol(None)
314368 # Set up the server side protocol
315369 channel = _PushHTTPChannel(self.reactor)
316370 channel.requestFactory = request_factory
317 channel.site = self.site
371 channel.site = self._hs_to_site[hs]
318372
319373 # Connect client to server and vice versa.
320374 client_to_server_transport = FakeTransport(
331385 # before the data starts flowing over the connections as this is called
332386 # inside `connectTCP` before the connection has been passed back to the
333387 # code that requested the TCP connection.
388
389 def connect_any_redis_attempts(self):
390 """If redis is enabled we need to deal with workers connecting to a
391 redis server. We don't want to use a real Redis server so we use a
392 fake one.
393 """
394 clients = self.reactor.tcpClients
395 self.assertEqual(len(clients), 1)
396 (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
397 self.assertEqual(host, "localhost")
398 self.assertEqual(port, 6379)
399
400 client_protocol = client_factory.buildProtocol(None)
401 server_protocol = self._redis_server.buildProtocol(None)
402
403 client_to_server_transport = FakeTransport(
404 server_protocol, self.reactor, client_protocol
405 )
406 client_protocol.makeConnection(client_to_server_transport)
407
408 server_to_client_transport = FakeTransport(
409 client_protocol, self.reactor, server_protocol
410 )
411 server_protocol.makeConnection(server_to_client_transport)
412
413 return client_to_server_transport, server_to_client_transport
334414
335415
336416 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
466546 pass
467547
468548 self.stopProducing()
549
550
551 class FakeRedisPubSubServer:
552 """A fake Redis server for pub/sub.
553 """
554
555 def __init__(self):
556 self._subscribers = set()
557
558 def add_subscriber(self, conn):
559 """A connection has called SUBSCRIBE
560 """
561 self._subscribers.add(conn)
562
563 def remove_subscriber(self, conn):
564 """A connection has called UNSUBSCRIBE
565 """
566 self._subscribers.discard(conn)
567
568 def publish(self, conn, channel, msg) -> int:
569 """A connection want to publish a message to subscribers.
570 """
571 for sub in self._subscribers:
572 sub.send(["message", channel, msg])
573
574 return len(self._subscribers)
575
576 def buildProtocol(self, addr):
577 return FakeRedisPubSubProtocol(self)
578
579
580 class FakeRedisPubSubProtocol(Protocol):
581 """A connection from a client talking to the fake Redis server.
582 """
583
584 def __init__(self, server: FakeRedisPubSubServer):
585 self._server = server
586 self._reader = hiredis.Reader()
587
588 def dataReceived(self, data):
589 self._reader.feed(data)
590
591 # We might get multiple messages in one packet.
592 while True:
593 msg = self._reader.gets()
594
595 if msg is False:
596 # No more messages.
597 return
598
599 if not isinstance(msg, list):
600 # Inbound commands should always be a list
601 raise Exception("Expected redis list")
602
603 self.handle_command(msg[0], *msg[1:])
604
605 def handle_command(self, command, *args):
606 """Received a Redis command from the client.
607 """
608
609 # We currently only support pub/sub.
610 if command == b"PUBLISH":
611 channel, message = args
612 num_subscribers = self._server.publish(self, channel, message)
613 self.send(num_subscribers)
614 elif command == b"SUBSCRIBE":
615 (channel,) = args
616 self._server.add_subscriber(self)
617 self.send(["subscribe", channel, 1])
618 else:
619 raise Exception("Unknown command")
620
621 def send(self, msg):
622 """Send a message back to the client.
623 """
624 raw = self.encode(msg).encode("utf-8")
625
626 self.transport.write(raw)
627 self.transport.flush()
628
629 def encode(self, obj):
630 """Encode an object to its Redis format.
631
632 Supports: strings/bytes, integers and list/tuples.
633 """
634
635 if isinstance(obj, bytes):
636 # We assume bytes are just unicode strings.
637 obj = obj.decode("utf-8")
638
639 if isinstance(obj, str):
640 return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
641 if isinstance(obj, int):
642 return ":{val}\r\n".format(val=obj)
643 if isinstance(obj, (list, tuple)):
644 items = "".join(self.encode(a) for a in obj)
645 return "*{len}\r\n{items}".format(len=len(obj), items=items)
646
647 raise Exception("Unrecognized type for encoding redis: %r: %r" % (type(obj), obj))
648
649 def connectionLost(self, reason):
650 self._server.remove_subscriber(self)
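
For reference, `encode` emits standard RESP framing, which is what the client-side `hiredis.Reader` parses. A standalone check of the reply sent for SUBSCRIBE (the channel name "events" is just an example value, and the helper below mirrors the fake server's encoder rather than being part of it):

def encode(obj):
    if isinstance(obj, bytes):
        obj = obj.decode("utf-8")
    if isinstance(obj, str):
        return "${}\r\n{}\r\n".format(len(obj), obj)
    if isinstance(obj, int):
        return ":{}\r\n".format(obj)
    if isinstance(obj, (list, tuple)):
        return "*{}\r\n{}".format(len(obj), "".join(encode(a) for a in obj))
    raise TypeError(obj)

# A 3-element array of two bulk strings and an integer.
assert encode(["subscribe", "events", 1]) == (
    "*3\r\n$9\r\nsubscribe\r\n$6\r\nevents\r\n:1\r\n"
)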
3030 return config
3131
3232 def make_homeserver(self, reactor, clock):
33 hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
33 hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
3434
3535 return hs
3636
206206 def create_room_with_remote_server(self, user, token, remote_server="other_server"):
207207 room = self.helper.create_room_as(user, tok=token)
208208 store = self.hs.get_datastore()
209 federation = self.hs.get_handlers().federation_handler
209 federation = self.hs.get_federation_handler()
210210
211211 prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
212212 room_version = self.get_success(store.get_room_version(room))
225225 }
226226
227227 builder = factory.for_room_version(room_version, event_dict)
228 join_event = self.get_success(builder.build(prev_event_ids))
228 join_event = self.get_success(builder.build(prev_event_ids, None))
229229
230230 self.get_success(federation.on_send_join_request(remote_server, join_event))
231231 self.replicate()
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import logging
15
16 from mock import patch
17
18 from synapse.api.room_versions import RoomVersion
19 from synapse.rest import admin
20 from synapse.rest.client.v1 import login, room
21 from synapse.rest.client.v2_alpha import sync
22
23 from tests.replication._base import BaseMultiWorkerStreamTestCase
24 from tests.utils import USE_POSTGRES_FOR_TESTS
25
26 logger = logging.getLogger(__name__)
27
28
29 class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
30 """Checks event persisting sharding works
31 """
32
33 # Event persister sharding requires postgres (due to needing
34 # `MultiWriterIdGenerator`).
35 if not USE_POSTGRES_FOR_TESTS:
36 skip = "Requires Postgres"
37
38 servlets = [
39 admin.register_servlets_for_client_rest_resource,
40 room.register_servlets,
41 login.register_servlets,
42 sync.register_servlets,
43 ]
44
45 def prepare(self, reactor, clock, hs):
46 # Register a user who sends a message that we'll get notified about
47 self.other_user_id = self.register_user("otheruser", "pass")
48 self.other_access_token = self.login("otheruser", "pass")
49
50 self.room_creator = self.hs.get_room_creation_handler()
51 self.store = hs.get_datastore()
52
53 def default_config(self):
54 conf = super().default_config()
55 conf["redis"] = {"enabled": "true"}
56 conf["stream_writers"] = {"events": ["worker1", "worker2"]}
57 conf["instance_map"] = {
58 "worker1": {"host": "testserv", "port": 1001},
59 "worker2": {"host": "testserv", "port": 1002},
60 }
61 return conf
62
63 def _create_room(self, room_id: str, user_id: str, tok: str):
64 """Create a room with given room_id
65 """
66
67 # We control the room ID generation by patching out the
68 # `_generate_room_id` method
69 async def generate_room(
70 creator_id: str, is_public: bool, room_version: RoomVersion
71 ):
72 await self.store.store_room(
73 room_id=room_id,
74 room_creator_user_id=creator_id,
75 is_public=is_public,
76 room_version=room_version,
77 )
78 return room_id
79
80 with patch(
81 "synapse.handlers.room.RoomCreationHandler._generate_room_id"
82 ) as mock:
83 mock.side_effect = generate_room
84 self.helper.create_room_as(user_id, tok=tok)
85
86 def test_basic(self):
87 """Simple test to ensure that multiple rooms can be created and joined,
88 and that different rooms get handled by different instances.
89 """
90
91 self.make_worker_hs(
92 "synapse.app.generic_worker", {"worker_name": "worker1"},
93 )
94
95 self.make_worker_hs(
96 "synapse.app.generic_worker", {"worker_name": "worker2"},
97 )
98
99 persisted_on_1 = False
100 persisted_on_2 = False
101
102 store = self.hs.get_datastore()
103
104 user_id = self.register_user("user", "pass")
105 access_token = self.login("user", "pass")
106
107 # Keep making new rooms until we see rooms being persisted on both
108 # workers.
109 for _ in range(10):
110 # Create a room
111 room = self.helper.create_room_as(user_id, tok=access_token)
112
113 # The other user joins
114 self.helper.join(
115 room=room, user=self.other_user_id, tok=self.other_access_token
116 )
117
118 # The other user sends some messages
119 response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
120 event_id = response["event_id"]
121
122 # The event position includes which instance persisted the event.
123 pos = self.get_success(store.get_position_for_event(event_id))
124
125 persisted_on_1 |= pos.instance_name == "worker1"
126 persisted_on_2 |= pos.instance_name == "worker2"
127
128 if persisted_on_1 and persisted_on_2:
129 break
130
131 self.assertTrue(persisted_on_1)
132 self.assertTrue(persisted_on_2)
133
134 def test_vector_clock_token(self):
135 """Tests that using a stream token with a vector clock component works
136 correctly with basic /sync and /messages usage.
137 """
138
139 self.make_worker_hs(
140 "synapse.app.generic_worker", {"worker_name": "worker1"},
141 )
142
143 worker_hs2 = self.make_worker_hs(
144 "synapse.app.generic_worker", {"worker_name": "worker2"},
145 )
146
147 sync_hs = self.make_worker_hs(
148 "synapse.app.generic_worker", {"worker_name": "sync"},
149 )
150
151 # Specially selected room IDs that get persisted on different workers.
152 room_id1 = "!foo:test"
153 room_id2 = "!baz:test"
154
155 self.assertEqual(
156 self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1"
157 )
158 self.assertEqual(
159 self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2"
160 )
161
162 user_id = self.register_user("user", "pass")
163 access_token = self.login("user", "pass")
164
165 store = self.hs.get_datastore()
166
167 # Create two rooms on different workers.
168 self._create_room(room_id1, user_id, access_token)
169 self._create_room(room_id2, user_id, access_token)
170
171 # The other user joins
172 self.helper.join(
173 room=room_id1, user=self.other_user_id, tok=self.other_access_token
174 )
175 self.helper.join(
176 room=room_id2, user=self.other_user_id, tok=self.other_access_token
177 )
178
179 # Do an initial sync so that we're up to date.
180 request, channel = self.make_request("GET", "/sync", access_token=access_token)
181 self.render_on_worker(sync_hs, request)
182 next_batch = channel.json_body["next_batch"]
183
184 # We now gut wrench into the events stream MultiWriterIdGenerator on
185 # worker2 to mimic it getting stuck persisting an event. This ensures
186 # that when we send an event on worker1 we end up in a state where
187 # worker2's event stream position lags that on worker1, resulting in a
188 # RoomStreamToken with a non-empty instance map component.
189 #
190 # Worker2's event stream position will not advance until we call
191 # __aexit__ again.
192 actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
193 self.get_success(actx.__aenter__())
194
195 response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
196 first_event_in_room1 = response["event_id"]
197
198 # Assert that the current stream token has an instance map component, as
199 # we are trying to test vector clock tokens.
200 room_stream_token = store.get_room_max_token()
201 self.assertNotEqual(len(room_stream_token.instance_map), 0)
202
203 # Check that syncing still gets the new event, despite the gap in the
204 # stream IDs.
205 request, channel = self.make_request(
206 "GET", "/sync?since={}".format(next_batch), access_token=access_token
207 )
208 self.render_on_worker(sync_hs, request)
209
210 # We should only see the new event and nothing else
211 self.assertIn(room_id1, channel.json_body["rooms"]["join"])
212 self.assertNotIn(room_id2, channel.json_body["rooms"]["join"])
213
214 events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"]
215 self.assertListEqual(
216 [first_event_in_room1], [event["event_id"] for event in events]
217 )
218
219 # Get the next batch and make sure it's a vector clock style token.
220 vector_clock_token = channel.json_body["next_batch"]
221 self.assertTrue(vector_clock_token.startswith("m"))
222
223 # Now that we've got a vector clock token we finish the fake persist of
224 # the event we started above.
225 self.get_success(actx.__aexit__(None, None, None))
226
227 # Now try and send an event to the other room so that we can test that
228 # the vector clock style token works as a `since` token.
229 response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
230 first_event_in_room2 = response["event_id"]
231
232 request, channel = self.make_request(
233 "GET",
234 "/sync?since={}".format(vector_clock_token),
235 access_token=access_token,
236 )
237 self.render_on_worker(sync_hs, request)
238
239 self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
240 self.assertIn(room_id2, channel.json_body["rooms"]["join"])
241
242 events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"]
243 self.assertListEqual(
244 [first_event_in_room2], [event["event_id"] for event in events]
245 )
246
247 next_batch = channel.json_body["next_batch"]
248
249 # We also want to test that the vector clock style token works with
250 # pagination. We do this by sending a new event into each room
251 # and syncing again to get a prev_batch token for each room, then
252 # paginating from there back to the vector clock token.
253 self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
254 self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
255
256 request, channel = self.make_request(
257 "GET", "/sync?since={}".format(next_batch), access_token=access_token
258 )
259 self.render_on_worker(sync_hs, request)
260
261 prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
262 "prev_batch"
263 ]
264 prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][
265 "prev_batch"
266 ]
267
268 # Paginating back in the first room should not produce any results, as
269 # no events have happened in it. This tests that we are correctly
270 # filtering results based on the vector clock portion.
271 request, channel = self.make_request(
272 "GET",
273 "/rooms/{}/messages?from={}&to={}&dir=b".format(
274 room_id1, prev_batch1, vector_clock_token
275 ),
276 access_token=access_token,
277 )
278 self.render_on_worker(sync_hs, request)
279 self.assertListEqual([], channel.json_body["chunk"])
280
281 # Paginating back on the second room should produce the first event
282 # again. This tests that pagination isn't completely broken.
283 request, channel = self.make_request(
284 "GET",
285 "/rooms/{}/messages?from={}&to={}&dir=b".format(
286 room_id2, prev_batch2, vector_clock_token
287 ),
288 access_token=access_token,
289 )
290 self.render_on_worker(sync_hs, request)
291 self.assertEqual(len(channel.json_body["chunk"]), 1)
292 self.assertEqual(
293 channel.json_body["chunk"][0]["event_id"], first_event_in_room2
294 )
295
296 # Paginating forwards should give the same results
297 request, channel = self.make_request(
298 "GET",
299 "/rooms/{}/messages?from={}&to={}&dir=f".format(
300 room_id1, vector_clock_token, prev_batch1
301 ),
302 access_token=access_token,
303 )
304 self.render_on_worker(sync_hs, request)
305 self.assertListEqual([], channel.json_body["chunk"])
306
307 request, channel = self.make_request(
308 "GET",
309 "/rooms/{}/messages?from={}&to={}&dir=f".format(
310 room_id2, vector_clock_token, prev_batch2,
311 ),
312 access_token=access_token,
313 )
314 self.render_on_worker(sync_hs, request)
315 self.assertEqual(len(channel.json_body["chunk"]), 1)
316 self.assertEqual(
317 channel.json_body["chunk"][0]["event_id"], first_event_in_room2
318 )
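
The `next_batch` checked above is a "vector clock" style token: alongside an overall stream position it carries per-writer positions for any writer that is ahead, so a reader can decide whether an event persisted by a lagging or leading instance falls before or after the token. A conceptual sketch only (not Synapse's `RoomStreamToken` implementation):

from typing import Dict

class VectorClockToken:
    def __init__(self, stream: int, instance_map: Dict[str, int]):
        self.stream = stream              # position that all writers have reached
        self.instance_map = instance_map  # per-writer positions beyond `stream`

    def includes(self, instance: str, stream_ordering: int) -> bool:
        """Is the event at (instance, stream_ordering) covered by this token?"""
        return stream_ordering <= self.instance_map.get(instance, self.stream)

token = VectorClockToken(stream=5, instance_map={"worker1": 8})
assert token.includes("worker1", 7)       # worker1 is ahead, so 7 is covered
assert not token.includes("worker2", 6)   # worker2 lags, so 6 is not yet covered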
7777
7878 def test_invite_3pid(self):
7979 """Ensure that a 3PID invite does not attempt to contact the identity server."""
80 identity_handler = self.hs.get_handlers().identity_handler
80 identity_handler = self.hs.get_identity_handler()
8181 identity_handler.lookup_3pid = Mock(
8282 side_effect=AssertionError("This should not get called")
8383 )
0 # -*- coding: utf-8 -*-
1 # Copyright 2019 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the 'License');
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an 'AS IS' BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import threading
15 from typing import Dict
16
17 from mock import Mock
18
19 from synapse.events import EventBase
20 from synapse.module_api import ModuleApi
21 from synapse.rest import admin
22 from synapse.rest.client.v1 import login, room
23 from synapse.types import Requester, StateMap
24
25 from tests import unittest
26
27 thread_local = threading.local()
28
29
30 class ThirdPartyRulesTestModule:
31 def __init__(self, config: Dict, module_api: ModuleApi):
32 # keep a record of the "current" rules module, so that the test can patch
33 # it if desired.
34 thread_local.rules_module = self
35 self.module_api = module_api
36
37 async def on_create_room(
38 self, requester: Requester, config: dict, is_requester_admin: bool
39 ):
40 return True
41
42 async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
43 return True
44
45 @staticmethod
46 def parse_config(config):
47 return config
48
49
50 def current_rules_module() -> ThirdPartyRulesTestModule:
51 return thread_local.rules_module
52
53
54 class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
55 servlets = [
56 admin.register_servlets,
57 login.register_servlets,
58 room.register_servlets,
59 ]
60
61 def default_config(self):
62 config = super().default_config()
63 config["third_party_event_rules"] = {
64 "module": __name__ + ".ThirdPartyRulesTestModule",
65 "config": {},
66 }
67 return config
68
69 def prepare(self, reactor, clock, homeserver):
70 # Create a user and room to play with during the tests
71 self.user_id = self.register_user("kermit", "monkey")
72 self.tok = self.login("kermit", "monkey")
73
74 self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
75
76 def test_third_party_rules(self):
77 """Tests that a forbidden event is forbidden from being sent, but an allowed one
78 can be sent.
79 """
80 # patch the rules module with a Mock which will return False for some event
81 # types
82 async def check(ev, state):
83 return ev.type != "foo.bar.forbidden"
84
85 callback = Mock(spec=[], side_effect=check)
86 current_rules_module().check_event_allowed = callback
87
88 request, channel = self.make_request(
89 "PUT",
90 "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
91 {},
92 access_token=self.tok,
93 )
94 self.render(request)
95 self.assertEquals(channel.result["code"], b"200", channel.result)
96
97 callback.assert_called_once()
98
99 # there should be various state events in the state arg: do some basic checks
100 state_arg = callback.call_args[0][1]
101 for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
102 self.assertIn(k, state_arg)
103 ev = state_arg[k]
104 self.assertEqual(ev.type, k[0])
105 self.assertEqual(ev.state_key, k[1])
106
107 request, channel = self.make_request(
108 "PUT",
109 "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
110 {},
111 access_token=self.tok,
112 )
113 self.render(request)
114 self.assertEquals(channel.result["code"], b"403", channel.result)
115
116 def test_cannot_modify_event(self):
117 """cannot accidentally modify an event before it is persisted"""
118
119 # first patch the event checker so that it will try to modify the event
120 async def check(ev: EventBase, state):
121 ev.content = {"x": "y"}
122 return True
123
124 current_rules_module().check_event_allowed = check
125
126 # now send the event
127 request, channel = self.make_request(
128 "PUT",
129 "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
130 {"x": "x"},
131 access_token=self.tok,
132 )
133 self.render(request)
134 self.assertEqual(channel.result["code"], b"500", channel.result)
135
136 def test_modify_event(self):
137 """The module can return a modified version of the event"""
138 # first patch the event checker so that it will modify the event
139 async def check(ev: EventBase, state):
140 d = ev.get_dict()
141 d["content"] = {"x": "y"}
142 return d
143
144 current_rules_module().check_event_allowed = check
145
146 # now send the event
147 request, channel = self.make_request(
148 "PUT",
149 "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
150 {"x": "x"},
151 access_token=self.tok,
152 )
153 self.render(request)
154 self.assertEqual(channel.result["code"], b"200", channel.result)
155 event_id = channel.json_body["event_id"]
156
157 # ... and check that it got modified
158 request, channel = self.make_request(
159 "GET",
160 "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
161 access_token=self.tok,
162 )
163 self.render(request)
164 self.assertEqual(channel.result["code"], b"200", channel.result)
165 ev = channel.json_body
166 self.assertEqual(ev["content"]["x"], "y")
167
168 def test_send_event(self):
169 """Tests that the module can send an event into a room via the module api"""
170 content = {
171 "msgtype": "m.text",
172 "body": "Hello!",
173 }
174 event_dict = {
175 "room_id": self.room_id,
176 "type": "m.room.message",
177 "content": content,
178 "sender": self.user_id,
179 }
180 event = self.get_success(
181 current_rules_module().module_api.create_and_send_event_into_room(
182 event_dict
183 )
184 ) # type: EventBase
185
186 self.assertEquals(event.sender, self.user_id)
187 self.assertEquals(event.room_id, self.room_id)
188 self.assertEquals(event.type, "m.room.message")
189 self.assertEquals(event.content, content)
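
For comparison with the Mock-patched callbacks used above, a complete rules module implementing the same interface might look like the following; the class name and behaviour are illustrative only, but the callback signature and the "return a dict to replace the event" convention are the ones these tests exercise:

from synapse.events import EventBase
from synapse.types import StateMap

class ExampleRulesModule:
    def __init__(self, config, module_api):
        self._config = config

    @staticmethod
    def parse_config(config):
        return config

    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
        if event.type == "foo.bar.forbidden":
            # Returning False rejects the event outright.
            return False
        if event.type == "modifyme":
            # Returning a dict replaces the event with the modified version.
            d = event.get_dict()
            d["content"] = {"x": "y"}
            return d
        return True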
+0
-79
tests/rest/client/third_party_rules.py
0 # -*- coding: utf-8 -*-
1 # Copyright 2019 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the 'License');
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an 'AS IS' BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from synapse.rest import admin
16 from synapse.rest.client.v1 import login, room
17
18 from tests import unittest
19
20
21 class ThirdPartyRulesTestModule:
22 def __init__(self, config):
23 pass
24
25 def check_event_allowed(self, event, context):
26 if event.type == "foo.bar.forbidden":
27 return False
28 else:
29 return True
30
31 @staticmethod
32 def parse_config(config):
33 return config
34
35
36 class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
37 servlets = [
38 admin.register_servlets,
39 login.register_servlets,
40 room.register_servlets,
41 ]
42
43 def make_homeserver(self, reactor, clock):
44 config = self.default_config()
45 config["third_party_event_rules"] = {
46 "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
47 "config": {},
48 }
49
50 self.hs = self.setup_test_homeserver(config=config)
51 return self.hs
52
53 def test_third_party_rules(self):
54 """Tests that a forbidden event is forbidden from being sent, but an allowed one
55 can be sent.
56 """
57 user_id = self.register_user("kermit", "monkey")
58 tok = self.login("kermit", "monkey")
59
60 room_id = self.helper.create_room_as(user_id, tok=tok)
61
62 request, channel = self.make_request(
63 "PUT",
64 "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
65 {},
66 access_token=tok,
67 )
68 self.render(request)
69 self.assertEquals(channel.result["code"], b"200", channel.result)
70
71 request, channel = self.make_request(
72 "PUT",
73 "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
74 {},
75 access_token=tok,
76 )
77 self.render(request)
78 self.assertEquals(channel.result["code"], b"403", channel.result)
2020 from synapse.util.stringutils import random_string
2121
2222 from tests import unittest
23 from tests.unittest import override_config
2324
2425
2526 class DirectoryTestCase(unittest.HomeserverTestCase):
6667 self.ensure_user_joined_room()
6768 self.set_alias_via_directory(400, alias_length=256)
6869
69 def test_state_event_in_room(self):
70 @override_config({"default_room_version": 5})
71 def test_state_event_user_in_v5_room(self):
72 """Test that a regular user can add alias events before room v6"""
7073 self.ensure_user_joined_room()
7174 self.set_alias_via_state_event(200)
75
76 @override_config({"default_room_version": 6})
77 def test_state_event_v6_room(self):
78 """Test that a regular user can *not* add alias events from room v6"""
79 self.ensure_user_joined_room()
80 self.set_alias_via_state_event(403)
7281
7382 def test_directory_in_room(self):
7483 self.ensure_user_joined_room()
4141
4242 hs = self.setup_test_homeserver(config=config)
4343
44 hs.get_handlers().federation_handler = Mock()
44 hs.get_federation_handler = Mock()
4545
4646 return hs
4747
3131 from synapse.util.stringutils import random_string
3232
3333 from tests import unittest
34 from tests.test_utils import make_awaitable
3435
3536 PATH_PREFIX = b"/_matrix/client/api/v1"
3637
4647 "red", http_client=None, federation_client=Mock(),
4748 )
4849
49 self.hs.get_federation_handler = Mock(return_value=Mock())
50 self.hs.get_federation_handler = Mock()
51 self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
52 return_value=make_awaitable(None)
53 )
5054
5155 async def _insert_client_ip(*args, **kwargs):
5256 return None
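
`maybe_backfill` is an async method, so the Mock standing in for it has to return something awaitable rather than a plain value; that is what the `make_awaitable` helper imported above provides. The general shape of the pattern, shown here with a local coroutine instead of that helper (note a coroutine can only be awaited once, which is why the tests use a dedicated helper):

import asyncio
from unittest.mock import Mock

async def _completed(result=None):
    # A coroutine standing in for an async method's return value.
    return result

handler = Mock()
handler.maybe_backfill = Mock(return_value=_completed(None))

async def _use():
    # Awaiting the mocked method behaves like awaiting the real async method.
    return await handler.maybe_backfill("!room:test", limit=10)

assert asyncio.run(_use()) is None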
4343
4444 self.event_source = hs.get_event_sources().sources["typing"]
4545
46 hs.get_handlers().federation_handler = Mock()
46 hs.get_federation_handler = Mock()
4747
4848 async def get_user_by_access_token(token=None, allow_guest=False):
4949 return {
351351 self.render(request)
352352 self.assertEqual(request.code, 401)
353353
354 @unittest.INFO
355354 def test_pending_invites(self):
356355 """Tests that deactivating a user rejects every pending invite for them."""
357356 store = self.hs.get_datastore()
103103 self.assertEqual(len(attempts), 1)
104104 self.assertEqual(attempts[0][0]["response"], "a")
105105
106 @unittest.INFO
107106 def test_fallback_captcha(self):
108107 """Ensure that fallback auth via a captcha works."""
109108 # Returns a 401 as per the spec
00 import json
11 import logging
2 from collections import deque
23 from io import SEEK_END, BytesIO
4 from typing import Callable
35
46 import attr
7 from typing_extensions import Deque
58 from zope.interface import implementer
69
710 from twisted.internet import address, threads, udp
250253 self._tcp_callbacks = {}
251254 self._udp = []
252255 lookups = self.lookups = {}
256 self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
253257
254258 @implementer(IResolverSimple)
255259 class FakeResolver:
271275 """
272276 Make the callback fire in the next reactor iteration.
273277 """
274 d = Deferred()
275 d.addCallback(lambda x: callback(*args, **kwargs))
276 self.callLater(0, d.callback, True)
277 return d
278 cb = lambda: callback(*args, **kwargs)
279 # it's not safe to call callLater() here, so we append the callback to a
280 # separate queue.
281 self._thread_callbacks.append(cb)
278282
279283 def getThreadPool(self):
280284 return self.threadpool
301305 callback()
302306
303307 return conn
308
309 def advance(self, amount):
310 # first advance our reactor's time, and run any "callLater" callbacks that
311 # this makes ready
312 super().advance(amount)
313
314 # now run any "callFromThread" callbacks
315 while True:
316 try:
317 callback = self._thread_callbacks.popleft()
318 except IndexError:
319 break
320 callback()
321
322 # check for more "callLater" callbacks added by the thread callback
323 # This isn't required in a regular reactor, but it ends up meaning that
324 # our database queries can complete in a single call to `advance` [1] which
325 # simplifies tests.
326 #
327 # [1]: we replace the threadpool backing the db connection pool with a
328 # mock ThreadPool which doesn't really use threads; but we still use
329 # reactor.callFromThread to feed results back from the db functions to the
330 # main thread.
331 super().advance(0)
304332
305333
306334 class ThreadPool:
337365 the homeserver.
338366 """
339367 server = _sth(cleanup_func, *args, **kwargs)
340
341 database = server.config.database.get_single_database()
342368
343369 # Make the thread pool synchronous.
344370 clock = server.get_clock()
371397 pool.threadpool = ThreadPool(clock._reactor)
372398 pool.running = True
373399
400 # We've just changed the Databases to run DB transactions on the same
401 # thread, so we need to disable the dedicated thread behaviour.
402 server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
403
374404 return server
375405
376406
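
The `callFromThread`/`advance` changes above boil down to a simple pattern: callbacks coming from "threads" are appended to a plain deque and drained explicitly when the fake clock advances, instead of being bounced through `callLater`. In isolation (names here are illustrative, not the test reactor's API):

from collections import deque
from typing import Callable, Deque

class FakeThreadedClock:
    def __init__(self):
        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]

    def callFromThread(self, callback, *args, **kwargs):
        # Not safe to schedule immediately; just queue the call.
        self._thread_callbacks.append(lambda: callback(*args, **kwargs))

    def advance(self, amount: float) -> None:
        # ... advance timers here, then drain the thread-callback queue ...
        while self._thread_callbacks:
            self._thread_callbacks.popleft()()

clock = FakeThreadedClock()
results = []
clock.callFromThread(results.append, 42)
clock.advance(0)
assert results == [42]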
1414 # limitations under the License.
1515
1616
17 from mock import Mock
18
19 from twisted.internet import defer
20
21 from synapse.util.async_helpers import ObservableDeferred
22 from synapse.util.caches.descriptors import Cache, cached
23
2417 from tests import unittest
25
26
27 class CacheTestCase(unittest.HomeserverTestCase):
28 def prepare(self, reactor, clock, homeserver):
29 self.cache = Cache("test")
30
31 def test_empty(self):
32 failed = False
33 try:
34 self.cache.get("foo")
35 except KeyError:
36 failed = True
37
38 self.assertTrue(failed)
39
40 def test_hit(self):
41 self.cache.prefill("foo", 123)
42
43 self.assertEquals(self.cache.get("foo"), 123)
44
45 def test_invalidate(self):
46 self.cache.prefill(("foo",), 123)
47 self.cache.invalidate(("foo",))
48
49 failed = False
50 try:
51 self.cache.get(("foo",))
52 except KeyError:
53 failed = True
54
55 self.assertTrue(failed)
56
57 def test_eviction(self):
58 cache = Cache("test", max_entries=2)
59
60 cache.prefill(1, "one")
61 cache.prefill(2, "two")
62 cache.prefill(3, "three") # 1 will be evicted
63
64 failed = False
65 try:
66 cache.get(1)
67 except KeyError:
68 failed = True
69
70 self.assertTrue(failed)
71
72 cache.get(2)
73 cache.get(3)
74
75 def test_eviction_lru(self):
76 cache = Cache("test", max_entries=2)
77
78 cache.prefill(1, "one")
79 cache.prefill(2, "two")
80
81 # Now access 1 again, thus causing 2 to be least-recently used
82 cache.get(1)
83
84 cache.prefill(3, "three")
85
86 failed = False
87 try:
88 cache.get(2)
89 except KeyError:
90 failed = True
91
92 self.assertTrue(failed)
93
94 cache.get(1)
95 cache.get(3)
96
97
98 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
99 @defer.inlineCallbacks
100 def test_passthrough(self):
101 class A:
102 @cached()
103 def func(self, key):
104 return key
105
106 a = A()
107
108 self.assertEquals((yield a.func("foo")), "foo")
109 self.assertEquals((yield a.func("bar")), "bar")
110
111 @defer.inlineCallbacks
112 def test_hit(self):
113 callcount = [0]
114
115 class A:
116 @cached()
117 def func(self, key):
118 callcount[0] += 1
119 return key
120
121 a = A()
122 yield a.func("foo")
123
124 self.assertEquals(callcount[0], 1)
125
126 self.assertEquals((yield a.func("foo")), "foo")
127 self.assertEquals(callcount[0], 1)
128
129 @defer.inlineCallbacks
130 def test_invalidate(self):
131 callcount = [0]
132
133 class A:
134 @cached()
135 def func(self, key):
136 callcount[0] += 1
137 return key
138
139 a = A()
140 yield a.func("foo")
141
142 self.assertEquals(callcount[0], 1)
143
144 a.func.invalidate(("foo",))
145
146 yield a.func("foo")
147
148 self.assertEquals(callcount[0], 2)
149
150 def test_invalidate_missing(self):
151 class A:
152 @cached()
153 def func(self, key):
154 return key
155
156 A().func.invalidate(("what",))
157
158 @defer.inlineCallbacks
159 def test_max_entries(self):
160 callcount = [0]
161
162 class A:
163 @cached(max_entries=10)
164 def func(self, key):
165 callcount[0] += 1
166 return key
167
168 a = A()
169
170 for k in range(0, 12):
171 yield a.func(k)
172
173 self.assertEquals(callcount[0], 12)
174
175 # There must have been at least 2 evictions, meaning if we calculate
176 # all 12 values again, we must get called at least 2 more times
177 for k in range(0, 12):
178 yield a.func(k)
179
180 self.assertTrue(
181 callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
182 )
183
184 def test_prefill(self):
185 callcount = [0]
186
187 d = defer.succeed(123)
188
189 class A:
190 @cached()
191 def func(self, key):
192 callcount[0] += 1
193 return d
194
195 a = A()
196
197 a.func.prefill(("foo",), ObservableDeferred(d))
198
199 self.assertEquals(a.func("foo").result, d.result)
200 self.assertEquals(callcount[0], 0)
201
202 @defer.inlineCallbacks
203 def test_invalidate_context(self):
204 callcount = [0]
205 callcount2 = [0]
206
207 class A:
208 @cached()
209 def func(self, key):
210 callcount[0] += 1
211 return key
212
213 @cached(cache_context=True)
214 def func2(self, key, cache_context):
215 callcount2[0] += 1
216 return self.func(key, on_invalidate=cache_context.invalidate)
217
218 a = A()
219 yield a.func2("foo")
220
221 self.assertEquals(callcount[0], 1)
222 self.assertEquals(callcount2[0], 1)
223
224 a.func.invalidate(("foo",))
225 yield a.func("foo")
226
227 self.assertEquals(callcount[0], 2)
228 self.assertEquals(callcount2[0], 1)
229
230 yield a.func2("foo")
231
232 self.assertEquals(callcount[0], 2)
233 self.assertEquals(callcount2[0], 2)
234
235 @defer.inlineCallbacks
236 def test_eviction_context(self):
237 callcount = [0]
238 callcount2 = [0]
239
240 class A:
241 @cached(max_entries=2)
242 def func(self, key):
243 callcount[0] += 1
244 return key
245
246 @cached(cache_context=True)
247 def func2(self, key, cache_context):
248 callcount2[0] += 1
249 return self.func(key, on_invalidate=cache_context.invalidate)
250
251 a = A()
252 yield a.func2("foo")
253 yield a.func2("foo2")
254
255 self.assertEquals(callcount[0], 2)
256 self.assertEquals(callcount2[0], 2)
257
258 yield a.func2("foo")
259 self.assertEquals(callcount[0], 2)
260 self.assertEquals(callcount2[0], 2)
261
262 yield a.func("foo3")
263
264 self.assertEquals(callcount[0], 3)
265 self.assertEquals(callcount2[0], 2)
266
267 yield a.func2("foo")
268
269 self.assertEquals(callcount[0], 4)
270 self.assertEquals(callcount2[0], 3)
271
272 @defer.inlineCallbacks
273 def test_double_get(self):
274 callcount = [0]
275 callcount2 = [0]
276
277 class A:
278 @cached()
279 def func(self, key):
280 callcount[0] += 1
281 return key
282
283 @cached(cache_context=True)
284 def func2(self, key, cache_context):
285 callcount2[0] += 1
286 return self.func(key, on_invalidate=cache_context.invalidate)
287
288 a = A()
289 a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
290
291 yield a.func2("foo")
292
293 self.assertEquals(callcount[0], 1)
294 self.assertEquals(callcount2[0], 1)
295
296 a.func2.invalidate(("foo",))
297 self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
298
299 yield a.func2("foo")
300 a.func2.invalidate(("foo",))
301 self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
302
303 self.assertEquals(callcount[0], 1)
304 self.assertEquals(callcount2[0], 2)
305
306 a.func.invalidate(("foo",))
307 self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
308 yield a.func("foo")
309
310 self.assertEquals(callcount[0], 2)
311 self.assertEquals(callcount2[0], 2)
312
313 yield a.func2("foo")
314
315 self.assertEquals(callcount[0], 2)
316 self.assertEquals(callcount2[0], 3)
31718
31819
31920 class UpsertManyTests(unittest.HomeserverTestCase):
5757 # must be done after inserts
5858 database = hs.get_datastores().databases[0]
5959 self.store = ApplicationServiceStore(
60 database, make_conn(database._database_config, database.engine), hs
60 database, make_conn(database._database_config, database.engine, "test"), hs
6161 )
6262
6363 def tearDown(self):
131131
132132 db_config = hs.config.get_single_database()
133133 self.store = TestTransactionStore(
134 database, make_conn(db_config, self.engine), hs
134 database, make_conn(db_config, self.engine, "test"), hs
135135 )
136136
137137 def _add_service(self, url, as_token, id):
243243 service = Mock(id=self.as_list[0]["id"])
244244 events = [Mock(event_id="e1"), Mock(event_id="e2")]
245245 txn = yield defer.ensureDeferred(
246 self.store.create_appservice_txn(service, events)
246 self.store.create_appservice_txn(service, events, [])
247247 )
248248 self.assertEquals(txn.id, 1)
249249 self.assertEquals(txn.events, events)
257257 yield self._insert_txn(service.id, 9644, events)
258258 yield self._insert_txn(service.id, 9645, events)
259259 txn = yield defer.ensureDeferred(
260 self.store.create_appservice_txn(service, events)
260 self.store.create_appservice_txn(service, events, [])
261261 )
262262 self.assertEquals(txn.id, 9646)
263263 self.assertEquals(txn.events, events)
269269 events = [Mock(event_id="e1"), Mock(event_id="e2")]
270270 yield self._set_last_txn(service.id, 9643)
271271 txn = yield defer.ensureDeferred(
272 self.store.create_appservice_txn(service, events)
272 self.store.create_appservice_txn(service, events, [])
273273 )
274274 self.assertEquals(txn.id, 9644)
275275 self.assertEquals(txn.events, events)
292292 yield self._insert_txn(self.as_list[3]["id"], 9643, events)
293293
294294 txn = yield defer.ensureDeferred(
295 self.store.create_appservice_txn(service, events)
295 self.store.create_appservice_txn(service, events, [])
296296 )
297297 self.assertEquals(txn.id, 9644)
298298 self.assertEquals(txn.events, events)
406406 self.assertEquals(
407407 {self.as_list[2]["id"], self.as_list[0]["id"]},
408408 {services[0].id, services[1].id},
409 )
410
411
412 class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
413 def make_homeserver(self, reactor, clock):
414 hs = self.setup_test_homeserver()
415 return hs
416
417 def prepare(self, hs, reactor, clock):
418 self.service = Mock(id="foo")
419 self.store = self.hs.get_datastore()
420 self.get_success(self.store.set_appservice_state(self.service, "up"))
421
422 def test_get_type_stream_id_for_appservice_no_value(self):
423 value = self.get_success(
424 self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
425 )
426 self.assertEquals(value, 0)
427
428 value = self.get_success(
429 self.store.get_type_stream_id_for_appservice(self.service, "presence")
430 )
431 self.assertEquals(value, 0)
432
433 def test_get_type_stream_id_for_appservice_invalid_type(self):
434 self.get_failure(
435 self.store.get_type_stream_id_for_appservice(self.service, "foobar"),
436 ValueError,
437 )
438
439 def test_set_type_stream_id_for_appservice(self):
440 read_receipt_value = 1024
441 self.get_success(
442 self.store.set_type_stream_id_for_appservice(
443 self.service, "read_receipt", read_receipt_value
444 )
445 )
446 result = self.get_success(
447 self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
448 )
449 self.assertEqual(result, read_receipt_value)
450
451 self.get_success(
452 self.store.set_type_stream_id_for_appservice(
453 self.service, "presence", read_receipt_value
454 )
455 )
456 result = self.get_success(
457 self.store.get_type_stream_id_for_appservice(self.service, "presence")
458 )
459 self.assertEqual(result, read_receipt_value)
460
461 def test_set_type_stream_id_for_appservice_invalid_type(self):
462 self.get_failure(
463 self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
464 ValueError,
409465 )
410466
411467
447503
448504 database = hs.get_datastores().databases[0]
449505 ApplicationServiceStore(
450 database, make_conn(database._database_config, database.engine), hs
506 database, make_conn(database._database_config, database.engine, "test"), hs
451507 )
452508
453509 @defer.inlineCallbacks
466522 with self.assertRaises(ConfigError) as cm:
467523 database = hs.get_datastores().databases[0]
468524 ApplicationServiceStore(
469 database, make_conn(database._database_config, database.engine), hs
525 database,
526 make_conn(database._database_config, database.engine, "test"),
527 hs,
470528 )
471529
472530 e = cm.exception
490548 with self.assertRaises(ConfigError) as cm:
491549 database = hs.get_datastores().databases[0]
492550 ApplicationServiceStore(
493 database, make_conn(database._database_config, database.engine), hs
551 database,
552 make_conn(database._database_config, database.engine, "test"),
553 hs,
494554 )
495555
496556 e = cm.exception
198198 first_id_gen = self._create_id_generator("first", writers=["first", "second"])
199199 second_id_gen = self._create_id_generator("second", writers=["first", "second"])
200200
201 self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
202 self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
201 # The first ID gen will notice that it can advance its token to 7 as it
202 # has no in-progress writes...
203 self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
204 self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
203205 self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
206
207 # ... but the second ID gen doesn't know that.
208 self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
209 self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
210 self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
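A toy model of the behaviour described in the comments above (purely illustrative, not Synapse's MultiWriterIdGenerator): each writer keeps its own view of every writer's position, and a writer with no in-flight writes may jump its own position up to the newest position it has seen, while other writers' views stay unchanged until they hear about it.

    class ToyIdPositionTracker:
        """Illustrative only: one writer's view of multi-writer stream positions."""

        def __init__(self, my_name, positions):
            self.my_name = my_name
            self.positions = dict(positions)
            self.in_flight = set()  # stream IDs allocated but not yet persisted

        def maybe_advance_own_position(self):
            if not self.in_flight:
                # Nothing of ours can still be unpersisted, so everything up to
                # the newest position we know about is complete from our viewpoint.
                self.positions[self.my_name] = max(self.positions.values())


    first = ToyIdPositionTracker("first", {"first": 3, "second": 7})
    first.maybe_advance_own_position()
    assert first.positions == {"first": 7, "second": 7}

    second = ToyIdPositionTracker("second", {"first": 3, "second": 7})
    second.maybe_advance_own_position()
    # "second" can only move its *own* position; its view of "first" stays at 3.
    assert second.positions == {"first": 3, "second": 7}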
204211
205212 # Try allocating a new ID gen and check that we only see position
206213 # advanced after we leave the context manager.
210217 self.assertEqual(stream_id, 8)
211218
212219 self.assertEqual(
213 first_id_gen.get_positions(), {"first": 3, "second": 7}
220 first_id_gen.get_positions(), {"first": 7, "second": 7}
214221 )
215222
216223 self.get_success(_get_next_async())
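A minimal sketch of the allocation pattern the assertions above exercise (assumed shape only: id_gen.get_next() is the async context manager used by the test, and persist_row is a hypothetical coroutine standing in for the actual database write):

    async def write_one_row(id_gen, persist_row):
        # The stream ID is reserved as soon as the context manager is entered...
        async with id_gen.get_next() as stream_id:
            await persist_row(stream_id)
        # ...but other readers only see get_positions() advance once the block
        # exits, i.e. once the write counts as persisted.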
278285 self._insert_row_with_id("first", 3)
279286 self._insert_row_with_id("second", 5)
280287
281 id_gen = self._create_id_generator("first", writers=["first", "second"])
288 id_gen = self._create_id_generator("worker", writers=["first", "second"])
282289
283290 self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
284291
318325
319326 id_gen = self._create_id_generator("first", writers=["first", "second"])
320327
321 self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
322
323 self.assertEqual(id_gen.get_persisted_upto_position(), 3)
328 self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
329
330 self.assertEqual(id_gen.get_persisted_upto_position(), 5)
324331
325332 async def _get_next_async():
326333 async with id_gen.get_next() as stream_id:
327334 self.assertEqual(stream_id, 6)
328 self.assertEqual(id_gen.get_persisted_upto_position(), 3)
335 self.assertEqual(id_gen.get_persisted_upto_position(), 5)
329336
330337 self.get_success(_get_next_async())
331338
387394 self._insert_row_with_id("second", 5)
388395
389396 # Initial config has two writers
390 id_gen = self._create_id_generator("first", writers=["first", "second"])
397 id_gen = self._create_id_generator("worker", writers=["first", "second"])
391398 self.assertEqual(id_gen.get_persisted_upto_position(), 3)
392399 self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
393400 self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
567574
568575 self.get_success(_get_next_async2())
569576
570 self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
577 self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
571578 self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
572579 self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
573580 self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
235235 self._event_id = event_id
236236
237237 @defer.inlineCallbacks
238 def build(self, prev_event_ids):
238 def build(self, prev_event_ids, auth_event_ids):
239239 built_event = yield defer.ensureDeferred(
240 self._base_builder.build(prev_event_ids)
240 self._base_builder.build(prev_event_ids, auth_event_ids)
241241 )
242242
243243 built_event._event_id = self._event_id
7474 }
7575 )
7676
77 self.handler = self.homeserver.get_handlers().federation_handler
77 self.handler = self.homeserver.get_federation_handler()
7878 self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
7979 context
8080 )
1414 # limitations under the License.
1515
1616 from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
17 from synapse.util.caches.descriptors import Cache
17 from synapse.util.caches.deferred_cache import DeferredCache
1818
1919 from tests import unittest
2020
137137 Caches produce metrics reflecting their state when scraped.
138138 """
139139 CACHE_NAME = "cache_metrics_test_fgjkbdfg"
140 cache = Cache(CACHE_NAME, max_entries=777)
140 cache = DeferredCache(CACHE_NAME, max_entries=777)
141141
142142 items = {
143143 x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
1616
1717 import mock
1818
19 from synapse.app.homeserver import phone_stats_home
19 from synapse.app.phone_stats_home import phone_stats_home
2020
2121 from tests.unittest import HomeserverTestCase
2222
1919 import inspect
2020 import logging
2121 import time
22 from typing import Optional, Tuple, Type, TypeVar, Union
22 from typing import Optional, Tuple, Type, TypeVar, Union, overload
2323
2424 from mock import Mock, patch
2525
240240 # create a site to wrap the resource.
241241 self.site = SynapseSite(
242242 logger_name="synapse.access.http.fake",
243 site_tag="test",
243 site_tag=self.hs.config.server.server_name,
244244 config=self.hs.config.server.listeners[0],
245245 resource=self.resource,
246246 server_version_string="1",
253253 if hasattr(self, "user_id"):
254254 if self.hijack_auth:
255255
256 # We need a valid token ID to satisfy foreign key constraints.
257 token_id = self.get_success(
258 self.hs.get_datastore().add_access_token_to_user(
259 self.helper.auth_user_id, "some_fake_token", None, None,
260 )
261 )
262
256263 async def get_user_by_access_token(token=None, allow_guest=False):
257264 return {
258265 "user": UserID.from_string(self.helper.auth_user_id),
259 "token_id": 1,
266 "token_id": token_id,
260267 "is_guest": False,
261268 }
262269
263270 async def get_user_by_req(request, allow_guest=False, rights="access"):
264271 return create_requester(
265272 UserID.from_string(self.helper.auth_user_id),
266 1,
273 token_id,
267274 False,
268275 False,
269276 None,
356363 Function to optionally be overridden in subclasses.
357364 """
358365
366 # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
367 # when the `request` arg isn't given, so we define an explicit overload to
368 # cover that case.
369 @overload
370 def make_request(
371 self,
372 method: Union[bytes, str],
373 path: Union[bytes, str],
374 content: Union[bytes, dict] = b"",
375 access_token: Optional[str] = None,
376 shorthand: bool = True,
377 federation_auth_origin: str = None,
378 content_is_form: bool = False,
379 ) -> Tuple[SynapseRequest, FakeChannel]:
380 ...
381
382 @overload
359383 def make_request(
360384 self,
361385 method: Union[bytes, str],
367391 federation_auth_origin: str = None,
368392 content_is_form: bool = False,
369393 ) -> Tuple[T, FakeChannel]:
394 ...
395
396 def make_request(
397 self,
398 method: Union[bytes, str],
399 path: Union[bytes, str],
400 content: Union[bytes, dict] = b"",
401 access_token: Optional[str] = None,
402 request: Type[T] = SynapseRequest,
403 shorthand: bool = True,
404 federation_auth_origin: str = None,
405 content_is_form: bool = False,
406 ) -> Tuple[T, FakeChannel]:
370407 """
371408 Create a SynapseRequest at the path using the method and containing the
372409 given content.
607644 if soft_failed:
608645 event.internal_metadata.soft_failed = True
609646
610 self.get_success(event_creator.send_nonmember_event(requester, event, context))
647 self.get_success(
648 event_creator.handle_new_client_event(requester, event, context)
649 )
611650
612651 return event.event_id
613652
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from functools import partial
16
17 from twisted.internet import defer
18
19 from synapse.util.caches.deferred_cache import DeferredCache
20
21 from tests.unittest import TestCase
22
23
24 class DeferredCacheTestCase(TestCase):
25 def test_empty(self):
26 cache = DeferredCache("test")
27 failed = False
28 try:
29 cache.get("foo")
30 except KeyError:
31 failed = True
32
33 self.assertTrue(failed)
34
35 def test_hit(self):
36 cache = DeferredCache("test")
37 cache.prefill("foo", 123)
38
39 self.assertEquals(self.successResultOf(cache.get("foo")), 123)
40
41 def test_hit_deferred(self):
42 cache = DeferredCache("test")
43 origin_d = defer.Deferred()
44 set_d = cache.set("k1", origin_d)
45
46 # get should return an incomplete deferred
47 get_d = cache.get("k1")
48 self.assertFalse(get_d.called)
49
50 # add a callback that will make sure that the set_d gets called before the get_d
51 def check1(r):
52 self.assertTrue(set_d.called)
53 return r
54
55 # TODO: Actually ObservableDeferred *doesn't* run its callbacks in order on py3.8.
56 # maybe we should fix that?
57 # get_d.addCallback(check1)
58
59 # now fire off all the deferreds
60 origin_d.callback(99)
61 self.assertEqual(self.successResultOf(origin_d), 99)
62 self.assertEqual(self.successResultOf(set_d), 99)
63 self.assertEqual(self.successResultOf(get_d), 99)
64
65 def test_callbacks(self):
66 """Invalidation callbacks are called at the right time"""
67 cache = DeferredCache("test")
68 callbacks = set()
69
70 # start with an entry, with a callback
71 cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
72
73 # now replace that entry with a pending result
74 origin_d = defer.Deferred()
75 set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
76
77 # ... and also make a get request
78 get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
79
80 # we don't expect the invalidation callback for the original value to have
81 # been called yet, even though get() will now return a different result.
82 # I'm not sure if that is by design or not.
83 self.assertEqual(callbacks, set())
84
85 # now fire off all the deferreds
86 origin_d.callback(20)
87 self.assertEqual(self.successResultOf(set_d), 20)
88 self.assertEqual(self.successResultOf(get_d), 20)
89
90 # now the original invalidation callback should have been called, but none of
91 # the others
92 self.assertEqual(callbacks, {"prefill"})
93 callbacks.clear()
94
95 # another update should invalidate both the previous results
96 cache.prefill("k1", 30)
97 self.assertEqual(callbacks, {"set", "get"})
98
99 def test_set_fail(self):
100 cache = DeferredCache("test")
101 callbacks = set()
102
103 # start with an entry, with a callback
104 cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
105
106 # now replace that entry with a pending result
107 origin_d = defer.Deferred()
108 set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
109
110 # ... and also make a get request
111 get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
112
113 # none of the callbacks should have been called yet
114 self.assertEqual(callbacks, set())
115
116 # oh noes! fails!
117 e = Exception("oops")
118 origin_d.errback(e)
119 self.assertIs(self.failureResultOf(set_d, Exception).value, e)
120 self.assertIs(self.failureResultOf(get_d, Exception).value, e)
121
122 # the callbacks for the failed requests should have been called.
123 # I'm not sure if this is deliberate or not.
124 self.assertEqual(callbacks, {"get", "set"})
125 callbacks.clear()
126
127 # the old value should still be returned now?
128 get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
129 self.assertEqual(self.successResultOf(get_d2), 10)
130
131 # replacing the value now should run the callbacks for those requests
132 # which got the original result
133 cache.prefill("k1", 30)
134 self.assertEqual(callbacks, {"prefill", "get2"})
135
136 def test_get_immediate(self):
137 cache = DeferredCache("test")
138 d1 = defer.Deferred()
139 cache.set("key1", d1)
140
141 # get_immediate should return default
142 v = cache.get_immediate("key1", 1)
143 self.assertEqual(v, 1)
144
145 # now complete the set
146 d1.callback(2)
147
148 # get_immediate should return result
149 v = cache.get_immediate("key1", 1)
150 self.assertEqual(v, 2)
151
152 def test_invalidate(self):
153 cache = DeferredCache("test")
154 cache.prefill(("foo",), 123)
155 cache.invalidate(("foo",))
156
157 failed = False
158 try:
159 cache.get(("foo",))
160 except KeyError:
161 failed = True
162
163 self.assertTrue(failed)
164
165 def test_invalidate_all(self):
166 cache = DeferredCache("testcache")
167
168 callback_record = [False, False]
169
170 def record_callback(idx):
171 callback_record[idx] = True
172
173 # add a couple of pending entries
174 d1 = defer.Deferred()
175 cache.set("key1", d1, partial(record_callback, 0))
176
177 d2 = defer.Deferred()
178 cache.set("key2", d2, partial(record_callback, 1))
179
180 # lookup should return pending deferreds
181 self.assertFalse(cache.get("key1").called)
182 self.assertFalse(cache.get("key2").called)
183
184 # let one of the lookups complete
185 d2.callback("result2")
186
187 # now the cache will return a completed deferred
188 self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
189
190 # now do the invalidation
191 cache.invalidate_all()
192
193 # lookup should fail
194 with self.assertRaises(KeyError):
195 cache.get("key1")
196 with self.assertRaises(KeyError):
197 cache.get("key2")
198
199 # both invalidation callbacks should have been called
200 self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
201 self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
202
203 # letting the other lookup complete should do nothing
204 d1.callback("result1")
205 with self.assertRaises(KeyError):
206 cache.get("key1", None)
207
208 def test_eviction(self):
209 cache = DeferredCache(
210 "test", max_entries=2, apply_cache_factor_from_config=False
211 )
212
213 cache.prefill(1, "one")
214 cache.prefill(2, "two")
215 cache.prefill(3, "three") # 1 will be evicted
216
217 failed = False
218 try:
219 cache.get(1)
220 except KeyError:
221 failed = True
222
223 self.assertTrue(failed)
224
225 cache.get(2)
226 cache.get(3)
227
228 def test_eviction_lru(self):
229 cache = DeferredCache(
230 "test", max_entries=2, apply_cache_factor_from_config=False
231 )
232
233 cache.prefill(1, "one")
234 cache.prefill(2, "two")
235
236 # Now access 1 again, thus causing 2 to be least-recently used
237 cache.get(1)
238
239 cache.prefill(3, "three")
240
241 failed = False
242 try:
243 cache.get(2)
244 except KeyError:
245 failed = True
246
247 self.assertTrue(failed)
248
249 cache.get(1)
250 cache.get(3)
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515 import logging
16 from functools import partial
16 from typing import Set
1717
1818 import mock
1919
4141 return make_deferred_yieldable(d)
4242
4343
44 class CacheTestCase(unittest.TestCase):
45 def test_invalidate_all(self):
46 cache = descriptors.Cache("testcache")
47
48 callback_record = [False, False]
49
50 def record_callback(idx):
51 callback_record[idx] = True
52
53 # add a couple of pending entries
54 d1 = defer.Deferred()
55 cache.set("key1", d1, partial(record_callback, 0))
56
57 d2 = defer.Deferred()
58 cache.set("key2", d2, partial(record_callback, 1))
59
60 # lookup should return observable deferreds
61 self.assertFalse(cache.get("key1").has_called())
62 self.assertFalse(cache.get("key2").has_called())
63
64 # let one of the lookups complete
65 d2.callback("result2")
66
67 # for now at least, the cache will return real results rather than an
68 # ObservableDeferred
69 self.assertEqual(cache.get("key2"), "result2")
70
71 # now do the invalidation
72 cache.invalidate_all()
73
74 # lookup should return none
75 self.assertIsNone(cache.get("key1", None))
76 self.assertIsNone(cache.get("key2", None))
77
79 # both invalidation callbacks should have been called
79 self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
80 self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
81
82 # letting the other lookup complete should do nothing
83 d1.callback("result1")
84 self.assertIsNone(cache.get("key1", None))
85
86
8744 class DescriptorTestCase(unittest.TestCase):
8845 @defer.inlineCallbacks
8946 def test_cache(self):
172129 # and a second call should result in a second exception
173130 d = obj.fn(1)
174131 self.failureResultOf(d, SynapseError)
132
133 def test_cache_with_async_exception(self):
134 """The wrapped function returns a failure
135 """
136
137 class Cls:
138 result = None
139 call_count = 0
140
141 @cached()
142 def fn(self, arg1):
143 self.call_count += 1
144 return self.result
145
146 obj = Cls()
147 callbacks = set() # type: Set[str]
148
149 # set off an asynchronous request
150 obj.result = origin_d = defer.Deferred()
151
152 d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
153 self.assertFalse(d1.called)
154
155 # a second request should also return a deferred, but should not call the
156 # function itself.
157 d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
158 self.assertFalse(d2.called)
159 self.assertEqual(obj.call_count, 1)
160
161 # no callbacks yet
162 self.assertEqual(callbacks, set())
163
164 # the original request fails
165 e = Exception("bzz")
166 origin_d.errback(e)
167
168 # ... which should cause the lookups to fail similarly
169 self.assertIs(self.failureResultOf(d1, Exception).value, e)
170 self.assertIs(self.failureResultOf(d2, Exception).value, e)
171
172 # ... and the callbacks to have been, uh, called.
173 self.assertEqual(callbacks, {"d1", "d2"})
174
175 # ... leaving the cache empty
176 self.assertEqual(len(obj.fn.cache.cache), 0)
177
178 # and a second call should work as normal
179 obj.result = defer.succeed(100)
180 d3 = obj.fn(1)
181 self.assertEqual(self.successResultOf(d3), 100)
182 self.assertEqual(obj.call_count, 2)
175183
176184 def test_cache_logcontexts(self):
177185 """Check that logcontexts are set and restored correctly when
354362 self.failureResultOf(d, SynapseError)
355363
356364
365 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
366 """More tests for @cached
367
368 The following is a set of tests that got lost in a different file for a while.
369
370 There are probably duplicates of the tests in DescriptorTestCase. Ideally the
371 duplicates would be removed and the two sets of classes combined.
372 """
373
374 @defer.inlineCallbacks
375 def test_passthrough(self):
376 class A:
377 @cached()
378 def func(self, key):
379 return key
380
381 a = A()
382
383 self.assertEquals((yield a.func("foo")), "foo")
384 self.assertEquals((yield a.func("bar")), "bar")
385
386 @defer.inlineCallbacks
387 def test_hit(self):
388 callcount = [0]
389
390 class A:
391 @cached()
392 def func(self, key):
393 callcount[0] += 1
394 return key
395
396 a = A()
397 yield a.func("foo")
398
399 self.assertEquals(callcount[0], 1)
400
401 self.assertEquals((yield a.func("foo")), "foo")
402 self.assertEquals(callcount[0], 1)
403
404 @defer.inlineCallbacks
405 def test_invalidate(self):
406 callcount = [0]
407
408 class A:
409 @cached()
410 def func(self, key):
411 callcount[0] += 1
412 return key
413
414 a = A()
415 yield a.func("foo")
416
417 self.assertEquals(callcount[0], 1)
418
419 a.func.invalidate(("foo",))
420
421 yield a.func("foo")
422
423 self.assertEquals(callcount[0], 2)
424
425 def test_invalidate_missing(self):
426 class A:
427 @cached()
428 def func(self, key):
429 return key
430
431 A().func.invalidate(("what",))
432
433 @defer.inlineCallbacks
434 def test_max_entries(self):
435 callcount = [0]
436
437 class A:
438 @cached(max_entries=10)
439 def func(self, key):
440 callcount[0] += 1
441 return key
442
443 a = A()
444
445 for k in range(0, 12):
446 yield a.func(k)
447
448 self.assertEquals(callcount[0], 12)
449
450 # There must have been at least 2 evictions, meaning if we calculate
451 # all 12 values again, we must get called at least 2 more times
452 for k in range(0, 12):
453 yield a.func(k)
454
455 self.assertTrue(
456 callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
457 )
458
459 def test_prefill(self):
460 callcount = [0]
461
462 d = defer.succeed(123)
463
464 class A:
465 @cached()
466 def func(self, key):
467 callcount[0] += 1
468 return d
469
470 a = A()
471
472 a.func.prefill(("foo",), 456)
473
474 self.assertEquals(a.func("foo").result, 456)
475 self.assertEquals(callcount[0], 0)
476
477 @defer.inlineCallbacks
478 def test_invalidate_context(self):
479 callcount = [0]
480 callcount2 = [0]
481
482 class A:
483 @cached()
484 def func(self, key):
485 callcount[0] += 1
486 return key
487
488 @cached(cache_context=True)
489 def func2(self, key, cache_context):
490 callcount2[0] += 1
491 return self.func(key, on_invalidate=cache_context.invalidate)
492
493 a = A()
494 yield a.func2("foo")
495
496 self.assertEquals(callcount[0], 1)
497 self.assertEquals(callcount2[0], 1)
498
499 a.func.invalidate(("foo",))
500 yield a.func("foo")
501
502 self.assertEquals(callcount[0], 2)
503 self.assertEquals(callcount2[0], 1)
504
505 yield a.func2("foo")
506
507 self.assertEquals(callcount[0], 2)
508 self.assertEquals(callcount2[0], 2)
509
510 @defer.inlineCallbacks
511 def test_eviction_context(self):
512 callcount = [0]
513 callcount2 = [0]
514
515 class A:
516 @cached(max_entries=2)
517 def func(self, key):
518 callcount[0] += 1
519 return key
520
521 @cached(cache_context=True)
522 def func2(self, key, cache_context):
523 callcount2[0] += 1
524 return self.func(key, on_invalidate=cache_context.invalidate)
525
526 a = A()
527 yield a.func2("foo")
528 yield a.func2("foo2")
529
530 self.assertEquals(callcount[0], 2)
531 self.assertEquals(callcount2[0], 2)
532
533 yield a.func2("foo")
534 self.assertEquals(callcount[0], 2)
535 self.assertEquals(callcount2[0], 2)
536
537 yield a.func("foo3")
538
539 self.assertEquals(callcount[0], 3)
540 self.assertEquals(callcount2[0], 2)
541
542 yield a.func2("foo")
543
544 self.assertEquals(callcount[0], 4)
545 self.assertEquals(callcount2[0], 3)
546
547 @defer.inlineCallbacks
548 def test_double_get(self):
549 callcount = [0]
550 callcount2 = [0]
551
552 class A:
553 @cached()
554 def func(self, key):
555 callcount[0] += 1
556 return key
557
558 @cached(cache_context=True)
559 def func2(self, key, cache_context):
560 callcount2[0] += 1
561 return self.func(key, on_invalidate=cache_context.invalidate)
562
563 a = A()
564 a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
565
566 yield a.func2("foo")
567
568 self.assertEquals(callcount[0], 1)
569 self.assertEquals(callcount2[0], 1)
570
571 a.func2.invalidate(("foo",))
572 self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
573
574 yield a.func2("foo")
575 a.func2.invalidate(("foo",))
576 self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
577
578 self.assertEquals(callcount[0], 1)
579 self.assertEquals(callcount2[0], 2)
580
581 a.func.invalidate(("foo",))
582 self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
583 yield a.func("foo")
584
585 self.assertEquals(callcount[0], 2)
586 self.assertEquals(callcount2[0], 2)
587
588 yield a.func2("foo")
589
590 self.assertEquals(callcount[0], 2)
591 self.assertEquals(callcount2[0], 3)
592
593
357594 class CachedListDescriptorTestCase(unittest.TestCase):
358595 @defer.inlineCallbacks
359596 def test_cache(self):
1818 from synapse.util.caches.lrucache import LruCache
1919 from synapse.util.caches.treecache import TreeCache
2020
21 from .. import unittest
21 from tests import unittest
22 from tests.unittest import override_config
2223
2324
2425 class LruCacheTestCase(unittest.HomeserverTestCase):
5859 self.assertEquals(cache.pop("key"), None)
5960
6061 def test_del_multi(self):
61 cache = LruCache(4, 2, cache_type=TreeCache)
62 cache = LruCache(4, keylen=2, cache_type=TreeCache)
6263 cache[("animal", "cat")] = "mew"
6364 cache[("animal", "dog")] = "woof"
6465 cache[("vehicles", "car")] = "vroom"
8283 cache.clear()
8384 self.assertEquals(len(cache), 0)
8485
86 @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
87 def test_special_size(self):
88 cache = LruCache(10, "mycache")
89 self.assertEqual(cache.max_size, 100)
90
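The test above depends on the per-cache factor scaling the declared size (max_entries 10 with a factor of 10 gives a max_size of 100). A rough illustration of that arithmetic only (not the actual LruCache code; the 0.5 fallback factor is an assumption for the sketch):

    def effective_max_size(declared_max, cache_name, per_cache_factors, default_factor=0.5):
        """Toy calculation: scale a cache's declared size by its configured factor."""
        factor = per_cache_factors.get(cache_name, default_factor)
        return int(declared_max * factor)


    assert effective_max_size(10, "mycache", {"mycache": 10}) == 100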
8591
8692 class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
8793 def test_get(self):
159165 m2 = Mock()
160166 m3 = Mock()
161167 m4 = Mock()
162 cache = LruCache(4, 2, cache_type=TreeCache)
168 cache = LruCache(4, keylen=2, cache_type=TreeCache)
163169
164170 cache.set(("a", "1"), "value", callbacks=[m1])
165171 cache.set(("a", "2"), "value", callbacks=[m2])
2020 import uuid
2121 import warnings
2222 from inspect import getcallargs
23 from typing import Type
2324 from urllib import parse as urlparse
2425
2526 from mock import Mock, patch
3738 from synapse.logging.context import current_context, set_current_context
3839 from synapse.server import HomeServer
3940 from synapse.storage import DataStore
41 from synapse.storage.database import LoggingDatabaseConnection
4042 from synapse.storage.engines import PostgresEngine, create_engine
4143 from synapse.storage.prepare_database import prepare_database
4244 from synapse.util.ratelimitutils import FederationRateLimiter
8789 host=POSTGRES_HOST,
8890 password=POSTGRES_PASSWORD,
8991 )
92 db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
9093 prepare_database(db_conn, db_engine, None)
9194 db_conn.close()
9295
189192 def setup_test_homeserver(
190193 cleanup_func,
191194 name="test",
192 datastore=None,
193195 config=None,
194196 reactor=None,
195 homeserverToUse=TestHomeServer,
196 **kargs
197 homeserver_to_use: Type[HomeServer] = TestHomeServer,
198 **kwargs
197199 ):
198200 """
199201 Setup a homeserver suitable for running tests against. Keyword arguments
216218
217219 config.ldap_enabled = False
218220
219 if "clock" not in kargs:
220 kargs["clock"] = MockClock()
221 if "clock" not in kwargs:
222 kwargs["clock"] = MockClock()
221223
222224 if USE_POSTGRES_FOR_TESTS:
223225 test_db = "synapse_test_%s" % uuid.uuid4().hex
246248
247249 # Create the database before we actually try and connect to it, based off
248250 # the template database we generate in setupdb()
249 if datastore is None and isinstance(db_engine, PostgresEngine):
251 if isinstance(db_engine, PostgresEngine):
250252 db_conn = db_engine.module.connect(
251253 database=POSTGRES_BASE_DB,
252254 user=POSTGRES_USER,
262264 cur.close()
263265 db_conn.close()
264266
265 if datastore is None:
266 hs = homeserverToUse(
267 name,
268 config=config,
269 version_string="Synapse/tests",
270 tls_server_context_factory=Mock(),
271 tls_client_options_factory=Mock(),
272 reactor=reactor,
273 **kargs
274 )
275
276 hs.setup()
277 if homeserverToUse.__name__ == "TestHomeServer":
278 hs.setup_master()
279
280 if isinstance(db_engine, PostgresEngine):
281 database = hs.get_datastores().databases[0]
282
283 # We need to do cleanup on PostgreSQL
284 def cleanup():
285 import psycopg2
286
287 # Close all the db pools
288 database._db_pool.close()
289
290 dropped = False
291
292 # Drop the test database
293 db_conn = db_engine.module.connect(
294 database=POSTGRES_BASE_DB,
295 user=POSTGRES_USER,
296 host=POSTGRES_HOST,
297 password=POSTGRES_PASSWORD,
298 )
299 db_conn.autocommit = True
300 cur = db_conn.cursor()
301
302 # Try a few times to drop the DB. Some things may hold on to the
303 # database for a few more seconds due to flakiness, preventing
304 # us from dropping it when the test is over. If we can't drop
305 # it, warn and move on.
306 for x in range(5):
307 try:
308 cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
309 db_conn.commit()
310 dropped = True
311 except psycopg2.OperationalError as e:
312 warnings.warn(
313 "Couldn't drop old db: " + str(e), category=UserWarning
314 )
315 time.sleep(0.5)
316
317 cur.close()
318 db_conn.close()
319
320 if not dropped:
321 warnings.warn("Failed to drop old DB.", category=UserWarning)
322
323 if not LEAVE_DB:
324 # Register the cleanup hook
325 cleanup_func(cleanup)
326
327 else:
328 hs = homeserverToUse(
329 name,
330 datastore=datastore,
331 config=config,
332 version_string="Synapse/tests",
333 tls_server_context_factory=Mock(),
334 tls_client_options_factory=Mock(),
335 reactor=reactor,
336 **kargs
337 )
267 hs = homeserver_to_use(
268 name, config=config, version_string="Synapse/tests", reactor=reactor,
269 )
270
271 # Install @cache_in_self attributes
272 for key, val in kwargs.items():
273 setattr(hs, key, val)
274
275 # Mock TLS
276 hs.tls_server_context_factory = Mock()
277 hs.tls_client_options_factory = Mock()
278
279 hs.setup()
280 if homeserver_to_use == TestHomeServer:
281 hs.setup_background_tasks()
282
283 if isinstance(db_engine, PostgresEngine):
284 database = hs.get_datastores().databases[0]
285
286 # We need to do cleanup on PostgreSQL
287 def cleanup():
288 import psycopg2
289
290 # Close all the db pools
291 database._db_pool.close()
292
293 dropped = False
294
295 # Drop the test database
296 db_conn = db_engine.module.connect(
297 database=POSTGRES_BASE_DB,
298 user=POSTGRES_USER,
299 host=POSTGRES_HOST,
300 password=POSTGRES_PASSWORD,
301 )
302 db_conn.autocommit = True
303 cur = db_conn.cursor()
304
305 # Try a few times to drop the DB. Some things may hold on to the
306 # database for a few more seconds due to flakiness, preventing
307 # us from dropping it when the test is over. If we can't drop
308 # it, warn and move on.
309 for x in range(5):
310 try:
311 cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
312 db_conn.commit()
313 dropped = True
314 except psycopg2.OperationalError as e:
315 warnings.warn(
316 "Couldn't drop old db: " + str(e), category=UserWarning
317 )
318 time.sleep(0.5)
319
320 cur.close()
321 db_conn.close()
322
323 if not dropped:
324 warnings.warn("Failed to drop old DB.", category=UserWarning)
325
326 if not LEAVE_DB:
327 # Register the cleanup hook
328 cleanup_func(cleanup)
338329
339330 # bcrypt is far too slow to be doing in unit tests
340331 # Need to let the HS build an auth handler and then mess with it
350341
351342 hs.get_auth_handler().validate_hash = validate_hash
352343
353 fed = kargs.get("resource_for_federation", None)
344 fed = kwargs.get("resource_for_federation", None)
354345 if fed:
355346 register_federation_servlets(hs, fed)
356347
157157 coverage html
158158
159159 [testenv:mypy]
160 skip_install = True
161160 deps =
162161 {[base]deps}
163 mypy==0.782
164 mypy-zope
165 extras = all
162 extras = all,mypy
166163 commands = mypy
167164
168165 # To find all folders that pass mypy you run: